commit 9b7ff403025ffafe6682c3641f9ea86ea261c10c parent 1d925136ffbac5559ebce5da21ed7d91980e859c Author: rhunk <101876869+rhunk@users.noreply.github.com> Date: Mon, 28 Aug 2023 02:06:58 +0200 feat: better proto utils - new proto editor Diffstat:
11 files changed, 256 insertions(+), 121 deletions(-)
diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/data/MessageSender.kt b/core/src/main/kotlin/me/rhunk/snapenhance/data/MessageSender.kt @@ -14,18 +14,18 @@ class MessageSender( companion object { val redSnapProto: (Boolean) -> ByteArray = {hasAudio -> ProtoWriter().apply { - write(11, 5) { - write(1) { - write(1) { - writeConstant(2, 0) - writeConstant(12, 0) - writeConstant(15, 0) + from(11, 5) { + from(1) { + from(1) { + addVarInt(2, 0) + addVarInt(12, 0) + addVarInt(15, 0) } - writeConstant(6, 0) + addVarInt(6, 0) } - write(2) { - writeConstant(5, if (hasAudio) 1 else 0) - writeBuffer(6, byteArrayOf()) + from(2) { + addVarInt(5, if (hasAudio) 1 else 0) + addBuffer(6, byteArrayOf()) } } }.toByteArray() @@ -33,15 +33,15 @@ class MessageSender( val audioNoteProto: (Long) -> ByteArray = { duration -> ProtoWriter().apply { - write(6, 1) { - write(1) { - writeConstant(2, 4) - write(5) { - writeConstant(1, 0) - writeConstant(2, 0) + from(6, 1) { + from(1) { + addVarInt(2, 4) + from(5) { + addVarInt(1, 0) + addVarInt(2, 0) } - writeConstant(7, 0) - writeConstant(13, duration) + addVarInt(7, 0) + addVarInt(13, duration) } } }.toByteArray() @@ -153,8 +153,8 @@ class MessageSender( fun sendChatMessage(conversations: List<SnapUUID>, message: String, onError: (Any) -> Unit = {}, onSuccess: () -> Unit = {}) { internalSendMessage(conversations, createLocalMessageContentTemplate(ContentType.CHAT, ProtoWriter().apply { - write(2) { - writeString(1, message) + from(2) { + addString(1, message) } }.toByteArray(), savePolicy = "LIFETIME"), CallbackBuilder(sendMessageCallback) .override("onSuccess", callback = { onSuccess() }) diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/MediaDownloader.kt b/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/MediaDownloader.kt @@ -466,7 +466,7 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp val messageReader = ProtoReader(messageContent) val urlProto: ByteArray = if (isArroyoMessage) { var finalProto: ByteArray? = null - messageReader.readPath(4)?.each(5) { + messageReader.eachBuffer(4, 5) { finalProto = getByteArray(1, 3) } finalProto!! diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/ProfilePictureDownloader.kt b/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/ProfilePictureDownloader.kt @@ -64,9 +64,9 @@ class ProfilePictureDownloader : Feature("ProfilePictureDownloader", loadParams } } - ProtoReader(content).readPath(1, 1, 2) { - friendUsername = getString(2) ?: return@readPath - readPath(4) { + ProtoReader(content).followPath(1, 1, 2) { + friendUsername = getString(2) ?: return@followPath + followPath(4) { backgroundUrl = getString(2) avatarUrl = getString(100) } diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/tweaks/GalleryMediaSendOverride.kt b/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/tweaks/GalleryMediaSendOverride.kt @@ -28,7 +28,7 @@ class GalleryMediaSendOverride : Feature("Gallery Media Send Override", loadPara //prevent story replies val messageProtoReader = ProtoReader(localMessageContent.content) - if (messageProtoReader.exists(7)) return@subscribe + if (messageProtoReader.contains(7)) return@subscribe event.canceled = true @@ -38,7 +38,7 @@ class GalleryMediaSendOverride : Feature("Gallery Media Send Override", loadPara dialog.dismiss() val overrideType = typeNames.keys.toTypedArray()[which] - if (overrideType != "ORIGINAL" && messageProtoReader.readPath(3)?.getCount(3) != 1) { + if (overrideType != "ORIGINAL" && messageProtoReader.followPath(3)?.getCount(3) != 1) { context.runOnUiThread { ViewAppearanceHelper.newAlertDialogBuilder(context.mainActivity!!) .setMessage(context.translation["gallery_media_send_override.multiple_media_toast"]) @@ -57,7 +57,7 @@ class GalleryMediaSendOverride : Feature("Gallery Media Send Override", loadPara "NOTE" -> { localMessageContent.contentType = ContentType.NOTE val mediaDuration = - messageProtoReader.getLong(3, 3, 5, 1, 1, 15) ?: 0 + messageProtoReader.getVarInt(3, 3, 5, 1, 1, 15) ?: 0 localMessageContent.content = MessageSender.audioNoteProto(mediaDuration) } diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/tweaks/UnlimitedSnapViewTime.kt b/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/tweaks/UnlimitedSnapViewTime.kt @@ -20,12 +20,12 @@ class UnlimitedSnapViewTime : if (message.messageContent.contentType != ContentType.SNAP) return@hookConstructor with(message.messageContent) { - val mediaAttributes = ProtoReader(this.content).readPath(11, 5, 2) ?: return@hookConstructor - if (mediaAttributes.exists(6)) return@hookConstructor + val mediaAttributes = ProtoReader(this.content).followPath(11, 5, 2) ?: return@hookConstructor + if (mediaAttributes.contains(6)) return@hookConstructor this.content = ProtoEditor(this.content).apply { edit(11, 5, 2) { - mediaAttributes.getInt(5)?.let { writeConstant(5, it) } - writeBuffer(6, byteArrayOf()) + remove(8) + addBuffer(6, byteArrayOf()) } }.toByteArray() } diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/ProtoEditor.kt b/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/ProtoEditor.kt @@ -1,39 +1,64 @@ package me.rhunk.snapenhance.util.protobuf + +typealias WireCallback = EditorContext.() -> Unit + +class EditorContext( + private val wires: MutableMap<Int, MutableList<Wire>> +) { + fun clear() { + wires.clear() + } + fun addWire(wire: Wire) { + wires.getOrPut(wire.id) { mutableListOf() }.add(wire) + } + fun addVarInt(id: Int, value: Int) = addVarInt(id, value.toLong()) + fun addVarInt(id: Int, value: Long) = addWire(Wire(id, WireType.VARINT, value)) + fun addBuffer(id: Int, value: ByteArray) = addWire(Wire(id, WireType.LENGTH_DELIMITED, value)) + fun addString(id: Int, value: String) = addBuffer(id, value.toByteArray()) + fun addFixed64(id: Int, value: Long) = addWire(Wire(id, WireType.FIXED64, value)) + fun addFixed32(id: Int, value: Int) = addWire(Wire(id, WireType.FIXED32, value)) + + fun firstOrNull(id: Int) = wires[id]?.firstOrNull() + fun getOrNull(id: Int) = wires[id] + fun get(id: Int) = wires[id]!! + + fun remove(id: Int) = wires.remove(id) + fun remove(id: Int, index: Int) = wires[id]?.removeAt(index) +} + class ProtoEditor( private var buffer: ByteArray ) { - fun edit(vararg path: Int, callback: ProtoWriter.() -> Unit) { - val writer = ProtoWriter() - callback(writer) - buffer = writeAtPath(path, 0, ProtoReader(buffer), writer.toByteArray()) + fun edit(vararg path: Int, callback: WireCallback) { + buffer = writeAtPath(path, 0, ProtoReader(buffer), callback) } - private fun writeAtPath(path: IntArray, currentIndex: Int, rootReader: ProtoReader, bufferToWrite: ByteArray): ByteArray { - if (currentIndex == path.size) { - return bufferToWrite - } - val id = path[currentIndex] + private fun writeAtPath(path: IntArray, currentIndex: Int, rootReader: ProtoReader, wireToWriteCallback: WireCallback): ByteArray { + val id = path.getOrNull(currentIndex) val output = ProtoWriter() - val wires = mutableListOf<Pair<Int, ByteArray>>() + val wires = mutableMapOf<Int, MutableList<Wire>>() - rootReader.list { tag, value -> - if (tag == id) { - val childReader = rootReader.readPath(id) + rootReader.forEach { wireId, value -> + wires.putIfAbsent(wireId, mutableListOf()) + if (id != null && wireId == id) { + val childReader = rootReader.followPath(id) if (childReader == null) { - wires.add(Pair(tag, value)) - return@list + wires.getOrPut(wireId) { mutableListOf() }.add(value) + return@forEach } - wires.add(Pair(tag, writeAtPath(path, currentIndex + 1, childReader, bufferToWrite))) - return@list + wires[wireId]!!.add(Wire(wireId, WireType.LENGTH_DELIMITED, writeAtPath(path, currentIndex + 1, childReader, wireToWriteCallback))) + return@forEach } - wires.add(Pair(tag, value)) + wires[wireId]!!.add(value) } - wires.forEach { (tag, value) -> - output.writeBuffer(tag, value) + if (currentIndex == path.size) { + wireToWriteCallback(EditorContext(wires)) } + wires.values.flatten().forEach(output::addWire) + return output.toByteArray() } diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/ProtoReader.kt b/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/ProtoReader.kt @@ -1,6 +1,6 @@ package me.rhunk.snapenhance.util.protobuf -data class Wire(val type: Int, val value: Any) +data class Wire(val id: Int, val type: WireType, val value: Any) class ProtoReader(private val buffer: ByteArray) { private var offset: Int = 0 @@ -32,19 +32,46 @@ class ProtoReader(private val buffer: ByteArray) { while (offset < buffer.size) { val tag = readVarInt().toInt() val id = tag ushr 3 - val type = tag and 0x7 + val type = WireType.fromValue(tag and 0x7) ?: break try { val value = when (type) { - 0 -> readVarInt().toString().toByteArray() - 2 -> { + WireType.VARINT -> readVarInt() + WireType.FIXED64 -> { + val bytes = ByteArray(8) + for (i in 0..7) { + bytes[i] = readByte() + } + bytes + } + WireType.LENGTH_DELIMITED -> { val length = readVarInt().toInt() - val value = buffer.copyOfRange(offset, offset + length) - offset += length - value + val bytes = ByteArray(length) + for (i in 0 until length) { + bytes[i] = readByte() + } + bytes + } + WireType.START_GROUP -> { + val bytes = mutableListOf<Byte>() + while (true) { + val b = readByte() + if (b.toInt() == WireType.END_GROUP.value) { + break + } + bytes.add(b) + } + bytes.toByteArray() } - else -> break + WireType.FIXED32 -> { + val bytes = ByteArray(4) + for (i in 0..3) { + bytes[i] = readByte() + } + bytes + } + WireType.END_GROUP -> continue } - values.getOrPut(id) { mutableListOf() }.add(Wire(type, value)) + values.getOrPut(id) { mutableListOf() }.add(Wire(id, type, value)) } catch (t: Throwable) { values.clear() break @@ -52,13 +79,19 @@ class ProtoReader(private val buffer: ByteArray) { } } - fun readPath(vararg ids: Int, reader: (ProtoReader.() -> Unit)? = null): ProtoReader? { + fun followPath(vararg ids: Int, excludeLast: Boolean = false, reader: (ProtoReader.() -> Unit)? = null): ProtoReader? { var thisReader = this - ids.forEach { id -> - if (!thisReader.exists(id)) { + ids.let { + if (excludeLast) { + it.sliceArray(0 until it.size - 1) + } else { + it + } + }.forEach { id -> + if (!thisReader.contains(id)) { return null } - thisReader = ProtoReader(thisReader.get(id) as ByteArray) + thisReader = ProtoReader(thisReader.getByteArray(id) ?: return null) } if (reader != null) { thisReader.reader() @@ -66,65 +99,77 @@ class ProtoReader(private val buffer: ByteArray) { return thisReader } - fun pathExists(vararg ids: Int): Boolean { + fun containsPath(vararg ids: Int): Boolean { var thisReader = this ids.forEach { id -> - if (!thisReader.exists(id)) { + if (!thisReader.contains(id)) { return false } - thisReader = ProtoReader(thisReader.get(id) as ByteArray) + thisReader = ProtoReader(thisReader.getByteArray(id) ?: return false) } return true } - fun getByteArray(id: Int) = values[id]?.first()?.value as ByteArray? - fun getByteArray(vararg ids: Int): ByteArray? { - if (ids.isEmpty() || ids.size < 2) { - return null - } - val lastId = ids.last() - var value: ByteArray? = null - readPath(*(ids.copyOfRange(0, ids.size - 1))) { - value = getByteArray(lastId) + fun forEach(reader: (Int, Wire) -> Unit) { + values.forEach { (id, wires) -> + wires.forEach { wire -> + reader(id, wire) + } } - return value } - fun getString(id: Int) = getByteArray(id)?.toString(Charsets.UTF_8) - fun getString(vararg ids: Int) = getByteArray(*ids)?.toString(Charsets.UTF_8) - - fun getInt(id: Int) = getString(id)?.toInt() - fun getInt(vararg ids: Int) = getString(*ids)?.toInt() - - fun getLong(id: Int) = getString(id)?.toLong() - fun getLong(vararg ids: Int) = getString(*ids)?.toLong() - - fun exists(id: Int) = values.containsKey(id) - - fun get(id: Int) = values[id]!!.first().value - - fun isValid() = values.isNotEmpty() - - fun getCount(id: Int) = values[id]!!.size + fun forEach(vararg id: Int, reader: ProtoReader.() -> Unit) { + followPath(*id)?.eachBuffer { _, buffer -> + ProtoReader(buffer).reader() + } + } - fun each(id: Int, reader: ProtoReader.(index: Int) -> Unit) { - values[id]!!.forEachIndexed { index, _ -> - ProtoReader(values[id]!![index].value as ByteArray).reader(index) + fun eachBuffer(vararg ids: Int, reader: ProtoReader.() -> Unit) { + followPath(*ids, excludeLast = true)?.eachBuffer { id, buffer -> + if (id == ids.last()) { + ProtoReader(buffer).reader() + } } } - fun list(reader: (id: Int, data: ByteArray) -> Unit) { + fun eachBuffer(reader: (Int, ByteArray) -> Unit) { values.forEach { (id, wires) -> - wires.forEachIndexed { index, _ -> - reader(id, wires[index].value as ByteArray) + wires.forEach { wire -> + if (wire.type == WireType.LENGTH_DELIMITED) { + reader(id, wire.value as ByteArray) + } } } } - fun eachExists(id: Int, reader: ProtoReader.(index: Int) -> Unit) { - if (!exists(id)) { - return + fun contains(id: Int) = values.containsKey(id) + + fun getWire(id: Int) = values[id]?.firstOrNull() + fun getRawValue(id: Int) = getWire(id)?.value + fun getByteArray(id: Int) = getRawValue(id) as? ByteArray + fun getByteArray(vararg ids: Int) = followPath(*ids, excludeLast = true)?.getByteArray(ids.last()) + fun getString(id: Int) = getByteArray(id)?.toString(Charsets.UTF_8) + fun getString(vararg ids: Int) = followPath(*ids, excludeLast = true)?.getString(ids.last()) + fun getVarInt(id: Int) = getRawValue(id) as? Long + fun getVarInt(vararg ids: Int) = followPath(*ids, excludeLast = true)?.getVarInt(ids.last()) + fun getCount(id: Int) = values[id]?.size ?: 0 + + fun getFixed64(id: Int): Long { + val bytes = getByteArray(id) ?: return 0L + var value = 0L + for (i in 0..7) { + value = value or ((bytes[i].toLong() and 0xFF) shl (i * 8)) + } + return value + } + + + fun getFixed32(id: Int): Int { + val bytes = getByteArray(id) ?: return 0 + var value = 0 + for (i in 0..3) { + value = value or ((bytes[i].toInt() and 0xFF) shl (i * 8)) } - each(id, reader) + return value } } \ No newline at end of file diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/ProtoWriter.kt b/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/ProtoWriter.kt @@ -23,43 +23,94 @@ class ProtoWriter { stream.write(v.toInt()) } - fun writeBuffer(id: Int, value: ByteArray) { - writeVarInt(id shl 3 or 2) + fun addBuffer(id: Int, value: ByteArray) { + writeVarInt(id shl 3 or WireType.LENGTH_DELIMITED.value) writeVarInt(value.size) stream.write(value) } - fun writeConstant(id: Int, value: Int) { - writeVarInt(id shl 3) - writeVarInt(value) - } + fun addVarInt(id: Int, value: Int) = addVarInt(id, value.toLong()) - fun writeConstant(id: Int, value: Long) { + fun addVarInt(id: Int, value: Long) { writeVarInt(id shl 3) writeVarLong(value) } - fun writeString(id: Int, value: String) = writeBuffer(id, value.toByteArray()) + fun addString(id: Int, value: String) = addBuffer(id, value.toByteArray()) + + fun addFixed32(id: Int, value: Int) { + writeVarInt(id shl 3 or WireType.FIXED32.value) + val bytes = ByteArray(4) + for (i in 0..3) { + bytes[i] = (value shr (i * 8)).toByte() + } + stream.write(bytes) + } + + fun addFixed64(id: Int, value: Long) { + writeVarInt(id shl 3 or WireType.FIXED64.value) + val bytes = ByteArray(8) + for (i in 0..7) { + bytes[i] = (value shr (i * 8)).toByte() + } + stream.write(bytes) + } - fun write(id: Int, writer: ProtoWriter.() -> Unit) { + fun from(id: Int, writer: ProtoWriter.() -> Unit) { val writerStream = ProtoWriter() writer(writerStream) - writeBuffer(id, writerStream.stream.toByteArray()) + addBuffer(id, writerStream.stream.toByteArray()) } - fun write(vararg ids: Int, writer: ProtoWriter.() -> Unit) { + fun from(vararg ids: Int, writer: ProtoWriter.() -> Unit) { val writerStream = ProtoWriter() writer(writerStream) var stream = writerStream.stream.toByteArray() ids.reversed().forEach { id -> with(ProtoWriter()) { - writeBuffer(id, stream) + addBuffer(id, stream) stream = this.stream.toByteArray() } } stream.let(this.stream::write) } + fun addWire(wire: Wire) { + writeVarInt(wire.id shl 3 or wire.type.value) + when (wire.type) { + WireType.VARINT -> writeVarLong(wire.value as Long) + WireType.FIXED64, WireType.FIXED32 -> { + when (wire.value) { + is Int -> { + val bytes = ByteArray(4) + for (i in 0..3) { + bytes[i] = (wire.value shr (i * 8)).toByte() + } + stream.write(bytes) + } + is Long -> { + val bytes = ByteArray(8) + for (i in 0..7) { + bytes[i] = (wire.value shr (i * 8)).toByte() + } + stream.write(bytes) + } + is ByteArray -> stream.write(wire.value) + } + } + WireType.LENGTH_DELIMITED -> { + val value = wire.value as ByteArray + writeVarInt(value.size) + stream.write(value) + } + WireType.START_GROUP -> { + val value = wire.value as ByteArray + stream.write(value) + } + WireType.END_GROUP -> return + } + } + fun toByteArray(): ByteArray { return stream.toByteArray() } diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/WireType.kt b/core/src/main/kotlin/me/rhunk/snapenhance/util/protobuf/WireType.kt @@ -0,0 +1,14 @@ +package me.rhunk.snapenhance.util.protobuf; + +enum class WireType(val value: Int) { + VARINT(0), + FIXED64(1), + LENGTH_DELIMITED(2), + START_GROUP(3), + END_GROUP(4), + FIXED32(5); + + companion object { + fun fromValue(value: Int) = values().firstOrNull { it.value == value } + } +} diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/util/snap/EncryptionHelper.kt b/core/src/main/kotlin/me/rhunk/snapenhance/util/snap/EncryptionHelper.kt @@ -13,12 +13,12 @@ import javax.crypto.spec.SecretKeySpec object EncryptionHelper { fun getEncryptionKeys(contentType: ContentType, messageProto: ProtoReader, isArroyo: Boolean): Pair<ByteArray, ByteArray>? { val messageMediaInfo = MediaDownloaderHelper.getMessageMediaInfo(messageProto, contentType, isArroyo) ?: return null - val encryptionProtoIndex = if (messageMediaInfo.exists(Constants.ENCRYPTION_PROTO_INDEX_V2)) { + val encryptionProtoIndex = if (messageMediaInfo.contains(Constants.ENCRYPTION_PROTO_INDEX_V2)) { Constants.ENCRYPTION_PROTO_INDEX_V2 } else { Constants.ENCRYPTION_PROTO_INDEX } - val encryptionProto = messageMediaInfo.readPath(encryptionProtoIndex) ?: return null + val encryptionProto = messageMediaInfo.followPath(encryptionProtoIndex) ?: return null var key: ByteArray = encryptionProto.getByteArray(1)!! var iv: ByteArray = encryptionProto.getByteArray(2)!! diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/util/snap/MediaDownloaderHelper.kt b/core/src/main/kotlin/me/rhunk/snapenhance/util/snap/MediaDownloaderHelper.kt @@ -19,12 +19,12 @@ import java.util.zip.ZipInputStream object MediaDownloaderHelper { fun getMessageMediaInfo(protoReader: ProtoReader, contentType: ContentType, isArroyo: Boolean): ProtoReader? { - val messageContainerPath = if (isArroyo) protoReader.readPath(*Constants.ARROYO_MEDIA_CONTAINER_PROTO_PATH)!! else protoReader + val messageContainerPath = if (isArroyo) protoReader.followPath(*Constants.ARROYO_MEDIA_CONTAINER_PROTO_PATH)!! else protoReader val mediaContainerPath = if (contentType == ContentType.NOTE) intArrayOf(6, 1, 1) else intArrayOf(5, 1, 1) return when (contentType) { - ContentType.NOTE -> messageContainerPath.readPath(*mediaContainerPath) - ContentType.SNAP -> messageContainerPath.readPath(*(intArrayOf(11) + mediaContainerPath)) + ContentType.NOTE -> messageContainerPath.followPath(*mediaContainerPath) + ContentType.SNAP -> messageContainerPath.followPath(*(intArrayOf(11) + mediaContainerPath)) ContentType.EXTERNAL_MEDIA -> { val externalMediaTypes = arrayOf( intArrayOf(3, 3), //normal external media @@ -32,7 +32,7 @@ object MediaDownloaderHelper { intArrayOf(7, 3) //original story reply ) externalMediaTypes.forEach { path -> - messageContainerPath.readPath(*(path + mediaContainerPath))?.also { return it } + messageContainerPath.followPath(*(path + mediaContainerPath))?.also { return it } } null }