commit 9cb9bd7a26c3ebeab1a26e6b2d52cdb7edc4c46b
parent 5a47e04093d05b2f811a10475421a067220b46d8
Author: rhunk <101876869+rhunk@users.noreply.github.com>
Date:   Sat, 16 Sep 2023 11:56:41 +0200

feat: multiple media chat export
- optimize message exporter download
- optimize zip download/extract

Diffstat:
Mapp/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt | 57++++++++++++++++++++++-----------------------------------
Mcore/src/main/assets/web/export_template.html | 117+++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/download/data/MediaEncryptionKeyPair.kt | 25+++++++++++++++++--------
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/util/download/RemoteMediaResolver.kt | 31+++++++++++++++++++++++--------
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/util/export/MessageExporter.kt | 101+++++++++++++++++++++++++++++++++----------------------------------------------
Dcore/src/main/kotlin/me/rhunk/snapenhance/core/util/snap/EncryptionHelper.kt | 73-------------------------------------------------------------------------
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/util/snap/MediaDownloaderHelper.kt | 81+++++++++++++++++++++++++++----------------------------------------------------
Mcore/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/MediaDownloader.kt | 52+++++++++++++++++++---------------------------------
Mcore/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/decoder/MessageDecoder.kt | 44+++++++++++++++++++++++++++++++++++++++++++-
Mcore/src/main/kotlin/me/rhunk/snapenhance/features/impl/tweaks/Notifications.kt | 35++++++++++++++++++-----------------
10 files changed, 291 insertions(+), 325 deletions(-)

diff --git a/app/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt b/app/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt @@ -23,17 +23,14 @@ import me.rhunk.snapenhance.core.download.data.DownloadMetadata import me.rhunk.snapenhance.core.download.data.DownloadRequest import me.rhunk.snapenhance.core.download.data.DownloadStage import me.rhunk.snapenhance.core.download.data.InputMedia -import me.rhunk.snapenhance.core.download.data.MediaEncryptionKeyPair +import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver +import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper import java.io.File import java.io.InputStream import java.net.HttpURLConnection import java.net.URL import java.util.zip.ZipInputStream -import javax.crypto.Cipher -import javax.crypto.CipherInputStream -import javax.crypto.spec.IvParameterSpec -import javax.crypto.spec.SecretKeySpec import javax.xml.parsers.DocumentBuilderFactory import javax.xml.transform.TransformerFactory import javax.xml.transform.dom.DOMSource @@ -110,14 +107,6 @@ class DownloadProcessor ( return files } - private fun decryptInputStream(inputStream: InputStream, encryption: MediaEncryptionKeyPair): InputStream { - val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") - val key = Base64.UrlSafe.decode(encryption.key) - val iv = Base64.UrlSafe.decode(encryption.iv) - cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(key, "AES"), IvParameterSpec(iv)) - return CipherInputStream(inputStream, cipher) - } - @SuppressLint("UnspecifiedRegisterReceiverFlag") private suspend fun saveMediaToGallery(inputFile: File, downloadObject: DownloadObject) { if (coroutineContext.job.isCancelled) return @@ -202,24 +191,16 @@ class DownloadProcessor ( downloadRequest.inputMedias.forEach { inputMedia -> fun handleInputStream(inputStream: InputStream) { createMediaTempFile().apply { - if (inputMedia.encryption != null) { - decryptInputStream(inputStream, - inputMedia.encryption!! - ).use { decryptedInputStream -> - decryptedInputStream.copyTo(outputStream()) - } - } else { - inputStream.copyTo(outputStream()) - } + (inputMedia.encryption?.decryptInputStream(inputStream) ?: inputStream).copyTo(outputStream()) }.also { downloadedMedias[inputMedia] = it } } launch { when (inputMedia.type) { DownloadMediaType.PROTO_MEDIA -> { - RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(inputMedia.content))?.let { inputStream -> - handleInputStream(inputStream) - } + RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(inputMedia.content), decryptionCallback = { it }, resultCallback = { + handleInputStream(it) + }) } DownloadMediaType.DIRECT_MEDIA -> { val decoded = Base64.UrlSafe.decode(inputMedia.content) @@ -359,20 +340,26 @@ class DownloadProcessor ( var shouldMergeOverlay = downloadRequest.shouldMergeOverlay //if there is a zip file, extract it and replace the downloaded media with the extracted ones - downloadedMedias.values.find { it.fileType == FileType.ZIP }?.let { entry -> - val extractedMedias = extractZip(entry.file.inputStream()).map { - InputMedia( - type = DownloadMediaType.LOCAL_MEDIA, - content = it.absolutePath - ) to DownloadedFile(it, FileType.fromFile(it)) + downloadedMedias.values.find { it.fileType == FileType.ZIP }?.let { zipFile -> + val oldDownloadedMedias = downloadedMedias.toMap() + downloadedMedias.clear() + + MediaDownloaderHelper.getSplitElements(zipFile.file.inputStream()) { type, inputStream -> + createMediaTempFile().apply { + inputStream.copyTo(outputStream()) + }.also { + downloadedMedias[InputMedia( + type = DownloadMediaType.LOCAL_MEDIA, + content = it.absolutePath, + isOverlay = type == SplitMediaAssetType.OVERLAY + )] = DownloadedFile(it, FileType.fromFile(it)) + } } - downloadedMedias.values.removeIf { - it.file.delete() - true + oldDownloadedMedias.forEach { (_, value) -> + value.file.delete() } - downloadedMedias.putAll(extractedMedias) shouldMergeOverlay = true } diff --git a/core/src/main/assets/web/export_template.html b/core/src/main/assets/web/export_template.html @@ -122,11 +122,16 @@ } - .media_container { + .chat_media { max-width: 300px; max-height: 500px; } + .overlay_media { + position: absolute; + pointer-events: none; + } + .red_snap_svg { color: var(--sigSnapWithoutSound); } @@ -140,7 +145,7 @@ <div style="display: none;"> <svg class="red_snap_svg" width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg"> <rect x="2.75" y="2.75" width="10.5" height="10.5" rx="1.808" stroke="currentColor" stroke-width="1.5"></rect> - </svg> + </svg> </div> <script> @@ -152,13 +157,24 @@ } function makeConversationSummary() { - const conversationTitle = conversationData.conversationName != null ? - conversationData.conversationName : - "DM with " + Object.values(participants).map(user => user.username).join(", ") + const conversationTitle = conversationData.conversationName != null ? + conversationData.conversationName : "DM with " + Object.values(participants).map(user => user.username).join(", ") document.querySelector(".conversation_summary .title").textContent = conversationTitle } + function decodeMedia(element) { + const decodedData = new Uint8Array( + inflate( + base64decode( + element.innerHTML.substring(5, element.innerHTML.length - 4) + ) + ) + ) + + return URL.createObjectURL(new Blob([decodedData])) + } + function makeConversationMessageContainer() { const messageTemplate = document.querySelector("#message_template") Object.values(conversationData.messages).forEach(message => { @@ -185,63 +201,88 @@ return headerElement })(document.createElement("div"))) - messageObject.appendChild(((elem) =>{ - elem.classList.add("content") + messageObject.appendChild(((messageContainer) =>{ + messageContainer.classList.add("content") - elem.innerHTML = message.serializedContent + messageContainer.innerHTML = message.serializedContent if (!message.serializedContent) { - elem.innerHTML = "" + messageContainer.innerHTML = "" let messageData = "" switch(message.type) { case "SNAP": - elem.appendChild(document.querySelector('.red_snap_svg').cloneNode(true)) + messageContainer.appendChild(document.querySelector('.red_snap_svg').cloneNode(true)) messageData = "Snap" break default: messageData = message.type - } - elem.innerHTML += messageData + messageContainer.innerHTML += messageData } - if (message.mediaReferences && message.mediaReferences.length > 0) { - //only get the first reference - const reference = Object.values(message.mediaReferences)[0] - let fetched = false - var observer = new IntersectionObserver(function(entries) { - if(!fetched && entries[0].isIntersecting === true) { - fetched = true + if (message.attachments && message.attachments.length > 0) { + let observers = [] + + message.attachments.forEach((attachment, index) => { + const mediaKey = attachment.key.replace(/(=)/g, "") + + observers.push(() => { + const originalMedia = document.querySelector('.media-ORIGINAL_' + mediaKey) + if (!originalMedia) { + return + } + + const originalMediaUrl = decodeMedia(originalMedia) - const mediaDiv = document.querySelector('.media-ORIGINAL_' + reference.content.replace(/(=)/g, "")) - if (!mediaDiv) return - - const content = mediaDiv.innerHTML.substring(5, mediaDiv.innerHTML.length - 4) - const decodedData = new Uint8Array(inflate(base64decode(content))) + const mediaContainer = document.createElement("div") + messageContainer.appendChild(mediaContainer) - const blob = new Blob([decodedData]) - const url = URL.createObjectURL(blob) - const imageTag = document.createElement("img") - imageTag.classList.add("media_container") - imageTag.src = url + imageTag.src = originalMediaUrl + imageTag.classList.add("chat_media") + mediaContainer.appendChild(imageTag) + imageTag.onerror = () => { - elem.removeChild(imageTag) + mediaContainer.removeChild(imageTag) const mediaTag = document.createElement(message.type === "NOTE" ? "audio" : "video") - mediaTag.classList.add("media_container") - mediaTag.src = url + mediaTag.classList.add("chat_media") + mediaTag.src = originalMediaUrl mediaTag.preload = "metadata" mediaTag.controls = true - elem.appendChild(mediaTag) + mediaContainer.appendChild(mediaTag) } - elem.innerHTML = "" - elem.appendChild(imageTag) + + const overlay = document.querySelector('.media-OVERLAY_' + mediaKey) + if (!overlay) { + return + } + + const overlayImage = document.createElement("img") + overlayImage.src = decodeMedia(overlay) + overlayImage.classList.add("chat_media") + overlayImage.classList.add("overlay_media") + mediaContainer.appendChild(overlayImage) + }) + }) + + let fetched = false + + new IntersectionObserver(entries => { + if(!fetched && entries[0].isIntersecting === true) { + fetched = true + messageContainer.innerHTML = "" + observers.forEach(c => { + try { + c() + } catch (e) { + console.log(e) + } + }) } - }, { threshold: [1] }); - observer.observe(elem) + }).observe(messageContainer) } - return elem + return messageContainer })(document.createElement("div"))) document.querySelector('.conversation_message_container').appendChild(messageObject) diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/download/data/MediaEncryptionKeyPair.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/download/data/MediaEncryptionKeyPair.kt @@ -3,6 +3,11 @@ package me.rhunk.snapenhance.core.download.data import me.rhunk.snapenhance.data.wrapper.impl.media.EncryptionWrapper +import java.io.InputStream +import javax.crypto.Cipher +import javax.crypto.CipherInputStream +import javax.crypto.spec.IvParameterSpec +import javax.crypto.spec.SecretKeySpec import kotlin.io.encoding.Base64 import kotlin.io.encoding.ExperimentalEncodingApi @@ -10,12 +15,16 @@ import kotlin.io.encoding.ExperimentalEncodingApi data class MediaEncryptionKeyPair( val key: String, val iv: String -) - -fun Pair<ByteArray, ByteArray>.toKeyPair(): MediaEncryptionKeyPair { - return MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.first), Base64.UrlSafe.encode(this.second)) +) { + fun decryptInputStream(inputStream: InputStream): InputStream { + val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") + cipher.init(Cipher.DECRYPT_MODE, SecretKeySpec(Base64.UrlSafe.decode(key), "AES"), IvParameterSpec(Base64.UrlSafe.decode(iv))) + return CipherInputStream(inputStream, cipher) + } } -fun EncryptionWrapper.toKeyPair(): MediaEncryptionKeyPair { - return MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.keySpec), Base64.UrlSafe.encode(this.ivKeyParameterSpec)) -}- \ No newline at end of file +fun Pair<ByteArray, ByteArray>.toKeyPair() + = MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.first), Base64.UrlSafe.encode(this.second)) + +fun EncryptionWrapper.toKeyPair() + = MediaEncryptionKeyPair(Base64.UrlSafe.encode(this.keySpec), Base64.UrlSafe.encode(this.ivKeyParameterSpec))+ \ No newline at end of file diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/util/download/RemoteMediaResolver.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/util/download/RemoteMediaResolver.kt @@ -4,7 +4,6 @@ import me.rhunk.snapenhance.Constants import me.rhunk.snapenhance.core.Logger import okhttp3.OkHttpClient import okhttp3.Request -import java.io.ByteArrayInputStream import java.io.InputStream import java.util.Base64 @@ -36,18 +35,34 @@ object RemoteMediaResolver { } .build() - fun downloadBoltMedia(protoKey: ByteArray): InputStream? { - val request = Request.Builder() - .url(BOLT_HTTP_RESOLVER_URL + "/resolve?co=" + Base64.getUrlEncoder().encodeToString(protoKey)) - .addHeader("User-Agent", Constants.USER_AGENT) - .build() + private fun newResolveRequest(protoKey: ByteArray) = Request.Builder() + .url(BOLT_HTTP_RESOLVER_URL + "/resolve?co=" + Base64.getUrlEncoder().encodeToString(protoKey)) + .addHeader("User-Agent", Constants.USER_AGENT) + .build() - okHttpClient.newCall(request).execute().use { response -> + /** + * Download bolt media with memory allocation + */ + fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }): ByteArray? { + okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response -> if (!response.isSuccessful) { Logger.directDebug("Unexpected code $response") return null } - return ByteArrayInputStream(response.body.bytes()) + return decryptionCallback(response.body.byteStream()).readBytes() + } + } + + fun downloadBoltMedia(protoKey: ByteArray, decryptionCallback: (InputStream) -> InputStream = { it }, resultCallback: (InputStream) -> Unit) { + okHttpClient.newCall(newResolveRequest(protoKey)).execute().use { response -> + if (!response.isSuccessful) { + throw Throwable("invalid response ${response.code}") + } + resultCallback( + decryptionCallback( + response.body.byteStream() + ) + ) } } } diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/util/export/MessageExporter.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/util/export/MessageExporter.kt @@ -3,6 +3,7 @@ package me.rhunk.snapenhance.core.util.export import android.os.Environment import android.util.Base64InputStream import com.google.gson.JsonArray +import com.google.gson.JsonNull import com.google.gson.JsonObject import de.robv.android.xposed.XposedHelpers import kotlinx.coroutines.Dispatchers @@ -11,20 +12,21 @@ import me.rhunk.snapenhance.ModContext import me.rhunk.snapenhance.core.BuildConfig import me.rhunk.snapenhance.core.database.objects.FriendFeedEntry import me.rhunk.snapenhance.core.database.objects.FriendInfo +import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver import me.rhunk.snapenhance.core.util.protobuf.ProtoReader -import me.rhunk.snapenhance.core.util.snap.EncryptionHelper import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper import me.rhunk.snapenhance.data.ContentType import me.rhunk.snapenhance.data.FileType -import me.rhunk.snapenhance.data.MediaReferenceType import me.rhunk.snapenhance.data.wrapper.impl.Message import me.rhunk.snapenhance.data.wrapper.impl.SnapUUID +import me.rhunk.snapenhance.features.impl.downloader.decoder.AttachmentType +import me.rhunk.snapenhance.features.impl.downloader.decoder.MessageDecoder +import java.io.BufferedInputStream import java.io.File import java.io.FileOutputStream import java.io.InputStream import java.io.OutputStream import java.text.SimpleDateFormat -import java.util.Base64 import java.util.Collections import java.util.Date import java.util.Locale @@ -33,6 +35,7 @@ import java.util.concurrent.TimeUnit import java.util.zip.Deflater import java.util.zip.DeflaterInputStream import java.util.zip.ZipFile +import kotlin.io.encoding.Base64 import kotlin.io.encoding.ExperimentalEncodingApi @@ -44,6 +47,7 @@ enum class ExportFormat( HTML("html"); } +@OptIn(ExperimentalEncodingApi::class) class MessageExporter( private val context: ModContext, private val outputFile: File, @@ -94,7 +98,6 @@ class MessageExporter( writer.flush() } - @OptIn(ExperimentalEncodingApi::class) suspend fun exportHtml(output: OutputStream) { val downloadMediaCacheFolder = File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS), "SnapEnhance/cache").also { it.mkdirs() } val mediaFiles = Collections.synchronizedMap(mutableMapOf<String, Pair<FileType, File>>()) @@ -115,34 +118,30 @@ class MessageExporter( mediaToDownload?.contains(it.messageContent.contentType) ?: false }.forEach { message -> threadPool.execute { - val remoteMediaReferences by lazy { - val serializedMessageContent = context.gson.toJsonTree(message.messageContent.instanceNonNull()).asJsonObject - serializedMessageContent["mRemoteMediaReferences"] - .asJsonArray.map { it.asJsonObject["mMediaReferences"].asJsonArray } - .flatten() - } - - remoteMediaReferences.firstOrNull().takeIf { it != null }?.let { media -> - val protoMediaReference = media.asJsonObject["mContentObject"].asJsonArray.map { it.asByte }.toByteArray() + MessageDecoder.decode(message.messageContent).forEach decode@{ attachment -> + val protoMediaReference = Base64.UrlSafe.decode(attachment.mediaKey ?: return@decode) runCatching { - val downloadedMedia = MediaDownloaderHelper.downloadMediaFromReference(protoMediaReference) { - EncryptionHelper.decryptInputStream(it, message.messageContent.contentType!!, ProtoReader(message.messageContent.content), isArroyo = false) - } - - downloadedMedia.forEach { (type, mediaData) -> - val fileType = FileType.fromByteArray(mediaData) - val fileName = "${type}_${kotlin.io.encoding.Base64.UrlSafe.encode(protoMediaReference).replace("=", "")}" - - val mediaFile = File(downloadMediaCacheFolder, "$fileName.${fileType.fileExtension}") - - FileOutputStream(mediaFile).use { fos -> - mediaData.inputStream().copyTo(fos) + RemoteMediaResolver.downloadBoltMedia(protoMediaReference, decryptionCallback = { + (attachment.attachmentInfo?.encryption?.decryptInputStream(it) ?: it) + }) { + it.use { inputStream -> + MediaDownloaderHelper.getSplitElements(inputStream) { type, splitInputStream -> + val fileName = "${type}_${Base64.UrlSafe.encode(protoMediaReference).replace("=", "")}" + val bufferedInputStream = BufferedInputStream(splitInputStream) + val fileType = MediaDownloaderHelper.getFileType(bufferedInputStream) + val mediaFile = File(downloadMediaCacheFolder, "$fileName.${fileType.fileExtension}") + + FileOutputStream(mediaFile).use { fos -> + bufferedInputStream.copyTo(fos) + } + + mediaFiles[fileName] = fileType to mediaFile + } } - - mediaFiles[fileName] = fileType to mediaFile - updateProgress("downloaded") } + + updateProgress("downloaded") }.onFailure { printLog("failed to download media for ${message.messageDescriptor.conversationId}_${message.orderKey}") context.log.error("failed to download media for ${message.messageDescriptor.conversationId}_${message.orderKey}", it) @@ -208,7 +207,7 @@ class MessageExporter( //export avenir next font apkFile.getEntry("res/font/avenir_next_medium.ttf").let { entry -> - val encodedFontData = kotlin.io.encoding.Base64.Default.encode(apkFile.getInputStream(entry).readBytes()) + val encodedFontData = Base64.Default.encode(apkFile.getInputStream(entry).readBytes()) output.write(""" <style> @font-face { @@ -284,41 +283,25 @@ class MessageExporter( addProperty("createdTimestamp", message.messageMetadata.createdAt) addProperty("readTimestamp", message.messageMetadata.readAt) addProperty("serializedContent", serializeMessageContent(message)) - addProperty("rawContent", Base64.getUrlEncoder().encodeToString(message.messageContent.content)) + addProperty("rawContent", Base64.UrlSafe.encode(message.messageContent.content)) - val messageContentType = message.messageContent.contentType ?: ContentType.CHAT - - EncryptionHelper.getEncryptionKeys(messageContentType, ProtoReader(message.messageContent.content), isArroyo = false)?.let { encryptionKeyPair -> - add("encryption", JsonObject().apply encryption@{ - addProperty("key", Base64.getEncoder().encodeToString(encryptionKeyPair.first)) - addProperty("iv", Base64.getEncoder().encodeToString(encryptionKeyPair.second)) - }) - } - - val remoteMediaReferences by lazy { - val serializedMessageContent = context.gson.toJsonTree(message.messageContent.instanceNonNull()).asJsonObject - serializedMessageContent["mRemoteMediaReferences"] - .asJsonArray.map { it.asJsonObject["mMediaReferences"].asJsonArray } - .flatten() - } - - add("mediaReferences", JsonArray().apply mediaReferences@ { - if (messageContentType != ContentType.EXTERNAL_MEDIA && - messageContentType != ContentType.STICKER && - messageContentType != ContentType.SNAP && - messageContentType != ContentType.NOTE) - return@mediaReferences - - remoteMediaReferences.forEach { media -> - val protoMediaReference = media.asJsonObject["mContentObject"].asJsonArray.map { it.asByte }.toByteArray() - val mediaType = MediaReferenceType.valueOf(media.asJsonObject["mMediaType"].asString) + add("attachments", JsonArray().apply { + MessageDecoder.decode(message.messageContent) + .forEach attachments@{ attachments -> + if (attachments.type == AttachmentType.STICKER) //TODO: implement stickers + return@attachments add(JsonObject().apply { - addProperty("mediaType", mediaType.toString()) - addProperty("content", Base64.getUrlEncoder().encodeToString(protoMediaReference)) + addProperty("key", attachments.mediaKey?.replace("=", "")) + addProperty("type", attachments.type.toString()) + add("encryption", attachments.attachmentInfo?.encryption?.let { encryption -> + JsonObject().apply { + addProperty("key", encryption.key) + addProperty("iv", encryption.iv) + } + } ?: JsonNull.INSTANCE) }) } }) - }) } }) diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/util/snap/EncryptionHelper.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/util/snap/EncryptionHelper.kt @@ -1,73 +0,0 @@ -package me.rhunk.snapenhance.core.util.snap - -import me.rhunk.snapenhance.Constants -import me.rhunk.snapenhance.core.download.data.MediaEncryptionKeyPair -import me.rhunk.snapenhance.core.util.protobuf.ProtoReader -import me.rhunk.snapenhance.data.ContentType -import java.io.InputStream -import javax.crypto.Cipher -import javax.crypto.CipherInputStream -import javax.crypto.spec.IvParameterSpec -import javax.crypto.spec.SecretKeySpec -import kotlin.io.encoding.Base64 -import kotlin.io.encoding.ExperimentalEncodingApi - -@OptIn(ExperimentalEncodingApi::class) -object EncryptionHelper { - fun getEncryptionKeys(contentType: ContentType, messageProto: ProtoReader, isArroyo: Boolean): Pair<ByteArray, ByteArray>? { - val mediaEncryptionInfo = MediaDownloaderHelper.getMessageMediaEncryptionInfo( - messageProto, - contentType, - isArroyo - ) ?: return null - val encryptionProtoIndex = if (mediaEncryptionInfo.contains(Constants.ENCRYPTION_PROTO_INDEX_V2)) { - Constants.ENCRYPTION_PROTO_INDEX_V2 - } else { - Constants.ENCRYPTION_PROTO_INDEX - } - val encryptionProto = mediaEncryptionInfo.followPath(encryptionProtoIndex) ?: return null - - var key: ByteArray = encryptionProto.getByteArray(1)!! - var iv: ByteArray = encryptionProto.getByteArray(2)!! - - if (encryptionProtoIndex == Constants.ENCRYPTION_PROTO_INDEX_V2) { - key = Base64.UrlSafe.decode(key) - iv = Base64.UrlSafe.decode(iv) - } - - return Pair(key, iv) - } - - fun decryptInputStream( - inputStream: InputStream, - contentType: ContentType, - messageProto: ProtoReader, - isArroyo: Boolean - ): InputStream { - val encryptionKeys = getEncryptionKeys(contentType, messageProto, isArroyo) ?: throw Exception("Failed to get encryption keys") - - Cipher.getInstance("AES/CBC/PKCS5Padding").apply { - init(Cipher.DECRYPT_MODE, SecretKeySpec(encryptionKeys.first, "AES"), IvParameterSpec(encryptionKeys.second)) - }.let { cipher -> - return CipherInputStream(inputStream, cipher) - } - } - - fun decryptInputStream( - inputStream: InputStream, - mediaEncryptionKeyPair: MediaEncryptionKeyPair? - ): InputStream { - if (mediaEncryptionKeyPair == null) { - return inputStream - } - - Cipher.getInstance("AES/CBC/PKCS5Padding").apply { - init(Cipher.DECRYPT_MODE, - SecretKeySpec(Base64.UrlSafe.decode(mediaEncryptionKeyPair.key), "AES"), - IvParameterSpec(Base64.UrlSafe.decode(mediaEncryptionKeyPair.iv)) - ) - }.let { cipher -> - return CipherInputStream(inputStream, cipher) - } - } -} diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/util/snap/MediaDownloaderHelper.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/util/snap/MediaDownloaderHelper.kt @@ -1,70 +1,45 @@ package me.rhunk.snapenhance.core.util.snap -import me.rhunk.snapenhance.Constants import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType -import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver -import me.rhunk.snapenhance.core.util.protobuf.ProtoReader -import me.rhunk.snapenhance.data.ContentType import me.rhunk.snapenhance.data.FileType -import java.io.ByteArrayInputStream -import java.io.FileNotFoundException +import java.io.BufferedInputStream import java.io.InputStream +import java.util.zip.ZipEntry import java.util.zip.ZipInputStream object MediaDownloaderHelper { - fun getMessageMediaEncryptionInfo(protoReader: ProtoReader, contentType: ContentType, isArroyo: Boolean): ProtoReader? { - val messageContainerPath = if (isArroyo) protoReader.followPath(*Constants.ARROYO_MEDIA_CONTAINER_PROTO_PATH)!! else protoReader - val mediaContainerPath = if (contentType == ContentType.NOTE) intArrayOf(6, 1, 1) else intArrayOf(5, 1, 1) + fun getFileType(bufferedInputStream: BufferedInputStream): FileType { + val buffer = ByteArray(16) + bufferedInputStream.mark(16) + bufferedInputStream.read(buffer) + bufferedInputStream.reset() + return FileType.fromByteArray(buffer) + } - return when (contentType) { - ContentType.NOTE -> messageContainerPath.followPath(*mediaContainerPath) - ContentType.SNAP -> messageContainerPath.followPath(*(intArrayOf(11) + mediaContainerPath)) - ContentType.EXTERNAL_MEDIA -> { - val externalMediaTypes = arrayOf( - intArrayOf(3, 3, *mediaContainerPath), //normal external media - intArrayOf(7, 15, 1, 1), //attached audio note - intArrayOf(7, 12, 3, *mediaContainerPath), //attached story reply - intArrayOf(7, 3, *mediaContainerPath), //original story reply - ) - externalMediaTypes.forEach { path -> - messageContainerPath.followPath(*path)?.also { return it } - } - null - } - else -> null + + fun getSplitElements( + inputStream: InputStream, + callback: (SplitMediaAssetType, InputStream) -> Unit + ) { + val bufferedInputStream = BufferedInputStream(inputStream) + val fileType = getFileType(bufferedInputStream) + + if (fileType != FileType.ZIP) { + callback(SplitMediaAssetType.ORIGINAL, bufferedInputStream) + return } - } - fun downloadMediaFromReference( - mediaReference: ByteArray, - decryptionCallback: (InputStream) -> InputStream, - ): Map<SplitMediaAssetType, ByteArray> { - val inputStream = RemoteMediaResolver.downloadBoltMedia(mediaReference) ?: throw FileNotFoundException("Unable to get media key. Check the logs for more info") - val content = decryptionCallback(inputStream).readBytes() - val fileType = FileType.fromByteArray(content) - val isZipFile = fileType == FileType.ZIP + val zipInputStream = ZipInputStream(bufferedInputStream) - //videos with overlay are packed in a zip file - //there are 2 files in the zip file, the video (webm) and the overlay (png) - if (isZipFile) { - var videoData: ByteArray? = null - var overlayData: ByteArray? = null - val zipInputStream = ZipInputStream(ByteArrayInputStream(content)) - while (zipInputStream.nextEntry != null) { - val zipEntryData: ByteArray = zipInputStream.readBytes() - val entryFileType = FileType.fromByteArray(zipEntryData) - if (entryFileType.isVideo) { - videoData = zipEntryData - } else if (entryFileType.isImage) { - overlayData = zipEntryData - } + var entry: ZipEntry? = zipInputStream.nextEntry + while (entry != null) { + if (entry.name.startsWith("overlay")) { + callback(SplitMediaAssetType.OVERLAY, zipInputStream) + } else if (entry.name.startsWith("media")) { + callback(SplitMediaAssetType.ORIGINAL, zipInputStream) } - videoData ?: throw FileNotFoundException("Unable to find video file in zip file") - overlayData ?: throw FileNotFoundException("Unable to find overlay file in zip file") - return mapOf(SplitMediaAssetType.ORIGINAL to videoData, SplitMediaAssetType.OVERLAY to overlayData) + entry = zipInputStream.nextEntry } - - return mapOf(SplitMediaAssetType.ORIGINAL to content) } } \ No newline at end of file diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/MediaDownloader.kt b/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/MediaDownloader.kt @@ -28,7 +28,6 @@ import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver import me.rhunk.snapenhance.core.util.ktx.getObjectField import me.rhunk.snapenhance.core.util.protobuf.ProtoReader import me.rhunk.snapenhance.core.util.snap.BitmojiSelfie -import me.rhunk.snapenhance.core.util.snap.EncryptionHelper import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper import me.rhunk.snapenhance.core.util.snap.PreviewUtils import me.rhunk.snapenhance.data.FileType @@ -47,6 +46,7 @@ import me.rhunk.snapenhance.hook.HookAdapter import me.rhunk.snapenhance.hook.HookStage import me.rhunk.snapenhance.hook.Hooker import me.rhunk.snapenhance.ui.ViewAppearanceHelper +import java.io.ByteArrayInputStream import java.nio.file.Paths import java.text.SimpleDateFormat import java.util.Locale @@ -526,42 +526,24 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp val friendInfo: FriendInfo = context.database.getFriendInfo(message.senderId!!) ?: throw Exception("Friend not found in database") val authorName = friendInfo.usernameForSorting!! - var messageContent = message.messageContent!! - var customMediaReferences = mutableListOf<String>() - - if (messageLogger.isMessageRemoved(message.clientConversationId!!, message.serverMessageId.toLong())) { + val decodedAttachments = if (messageLogger.isMessageRemoved(message.clientConversationId!!, message.serverMessageId.toLong())) { val messageObject = messageLogger.getMessageObject(message.clientConversationId!!, message.serverMessageId.toLong()) ?: throw Exception("Message not found in database") - val messageContentObject = messageObject.getAsJsonObject("mMessageContent") - - messageContent = messageContentObject - .getAsJsonArray("mContent") - .map { it.asByte } - .toByteArray() - - customMediaReferences = messageContentObject - .getAsJsonArray("mRemoteMediaReferences") - .map { it.asJsonObject.getAsJsonArray("mMediaReferences") } - .flatten().map { reference -> - Base64.UrlSafe.encode( - reference.asJsonObject.getAsJsonArray("mContentObject").map { it.asByte }.toByteArray() - ) - } - .toMutableList() + MessageDecoder.decode(messageObject.getAsJsonObject("mMessageContent")) + } else { + MessageDecoder.decode( + protoReader = ProtoReader(message.messageContent!!) + ) } - val messageReader = ProtoReader(messageContent) - val decodedAttachments = MessageDecoder.decode( - protoReader = messageReader, - customMediaReferences = customMediaReferences.takeIf { it.isNotEmpty() } - ) - if (decodedAttachments.isEmpty()) { context.shortToast(translations["no_attachments_toast"]) return } if (!isPreview) { - if (decodedAttachments.size == 1) { + if (decodedAttachments.size == 1 || + context.mainActivity == null // we can't show alert dialogs when it downloads from a notification, so it downloads the first one + ) { downloadMessageAttachments(friendInfo, message, authorName, listOf(decodedAttachments.first()) ) @@ -600,11 +582,15 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp val firstAttachment = decodedAttachments.first() val previewCoroutine = async { - val downloadedMediaList = MediaDownloaderHelper.downloadMediaFromReference(Base64.decode(firstAttachment.mediaKey!!)) { - EncryptionHelper.decryptInputStream( - it, - decodedAttachments.first().attachmentInfo?.encryption - ) + val downloadedMedia = RemoteMediaResolver.downloadBoltMedia(Base64.decode(firstAttachment.mediaKey!!), decryptionCallback = { + firstAttachment.attachmentInfo?.encryption?.decryptInputStream(it) ?: it + }) ?: return@async null + + val downloadedMediaList = mutableMapOf<SplitMediaAssetType, ByteArray>() + + MediaDownloaderHelper.getSplitElements(ByteArrayInputStream(downloadedMedia)) { + type, inputStream -> + downloadedMediaList[type] = inputStream.readBytes() } val originalMedia = downloadedMediaList[SplitMediaAssetType.ORIGINAL] ?: return@async null diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/decoder/MessageDecoder.kt b/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/downloader/decoder/MessageDecoder.kt @@ -1,7 +1,11 @@ package me.rhunk.snapenhance.features.impl.downloader.decoder +import com.google.gson.GsonBuilder +import com.google.gson.JsonElement +import com.google.gson.JsonObject import me.rhunk.snapenhance.core.download.data.toKeyPair import me.rhunk.snapenhance.core.util.protobuf.ProtoReader +import me.rhunk.snapenhance.data.wrapper.impl.MessageContent import kotlin.io.encoding.Base64 import kotlin.io.encoding.ExperimentalEncodingApi @@ -13,6 +17,8 @@ data class DecodedAttachment( @OptIn(ExperimentalEncodingApi::class) object MessageDecoder { + private val gson = GsonBuilder().create() + private fun decodeAttachment(protoReader: ProtoReader): AttachmentInfo? { val mediaInfo = protoReader.followPath(1, 1) ?: return null @@ -39,6 +45,43 @@ object MessageDecoder { ) } + @OptIn(ExperimentalEncodingApi::class) + fun getEncodedMediaReferences(messageContent: JsonElement): List<String> { + return getMediaReferences(messageContent).map { reference -> + Base64.UrlSafe.encode( + reference.asJsonObject.getAsJsonArray("mContentObject").map { it.asByte }.toByteArray() + ) + } + .toList() + } + + fun getMediaReferences(messageContent: JsonElement): List<JsonElement> { + return messageContent.asJsonObject.getAsJsonArray("mRemoteMediaReferences") + .asSequence() + .map { it.asJsonObject.getAsJsonArray("mMediaReferences") } + .flatten() + .sortedBy { + it.asJsonObject["mMediaListId"].asLong + }.toList() + } + + + fun decode(messageContent: MessageContent): List<DecodedAttachment> { + return decode( + ProtoReader(messageContent.content), + customMediaReferences = getEncodedMediaReferences(gson.toJsonTree(messageContent.instanceNonNull())) + ) + } + + fun decode(messageContent: JsonObject): List<DecodedAttachment> { + return decode( + ProtoReader(messageContent.getAsJsonArray("mContent") + .map { it.asByte } + .toByteArray()), + customMediaReferences = getEncodedMediaReferences(messageContent) + ) + } + fun decode( protoReader: ProtoReader, customMediaReferences: List<String>? = null // when customReferences is null it means that the message is from arroyo database @@ -138,7 +181,6 @@ object MessageDecoder { } } - return decodedAttachment } } \ No newline at end of file diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/tweaks/Notifications.kt b/core/src/main/kotlin/me/rhunk/snapenhance/features/impl/tweaks/Notifications.kt @@ -16,9 +16,9 @@ import me.rhunk.snapenhance.core.Logger import me.rhunk.snapenhance.core.download.data.SplitMediaAssetType import me.rhunk.snapenhance.core.eventbus.events.impl.SnapWidgetBroadcastReceiveEvent import me.rhunk.snapenhance.core.util.CallbackBuilder +import me.rhunk.snapenhance.core.util.download.RemoteMediaResolver import me.rhunk.snapenhance.core.util.ktx.setObjectField import me.rhunk.snapenhance.core.util.protobuf.ProtoReader -import me.rhunk.snapenhance.core.util.snap.EncryptionHelper import me.rhunk.snapenhance.core.util.snap.MediaDownloaderHelper import me.rhunk.snapenhance.core.util.snap.PreviewUtils import me.rhunk.snapenhance.core.util.snap.SnapWidgetBroadcastReceiverHelper @@ -34,7 +34,6 @@ import me.rhunk.snapenhance.features.impl.downloader.decoder.MessageDecoder import me.rhunk.snapenhance.hook.HookStage import me.rhunk.snapenhance.hook.Hooker import me.rhunk.snapenhance.hook.hook -import kotlin.io.encoding.Base64 import kotlin.io.encoding.ExperimentalEncodingApi class Notifications : Feature("Notifications", loadParams = FeatureLoadParams.INIT_SYNC) { @@ -246,29 +245,31 @@ class Notifications : Feature("Notifications", loadParams = FeatureLoadParams.IN appendNotifications() } ContentType.SNAP, ContentType.EXTERNAL_MEDIA -> { - val serializedMessageContent = context.gson.toJsonTree(snapMessage.messageContent.instanceNonNull()).asJsonObject - val mediaReferences = serializedMessageContent - .getAsJsonArray("mRemoteMediaReferences") - .map { it.asJsonObject.getAsJsonArray("mMediaReferences") } - .flatten() + val mediaReferences = MessageDecoder.getMediaReferences( + messageContent = context.gson.toJsonTree(snapMessage.messageContent.instanceNonNull()) + ) - val mediaReferenceUrls = mediaReferences.map { reference -> + val mediaReferenceKeys = mediaReferences.map { reference -> reference.asJsonObject.getAsJsonArray("mContentObject").map { it.asByte }.toByteArray() } - MessageDecoder.decode( - ProtoReader(contentData), - customMediaReferences = mediaReferenceUrls.map { Base64.UrlSafe.encode(it) } - ).forEachIndexed { index, media -> - val mediaType = MediaReferenceType.valueOf(mediaReferences[index].asJsonObject["mMediaType"].asString) + MessageDecoder.decode(snapMessage.messageContent).firstOrNull()?.also { media -> + val mediaType = MediaReferenceType.valueOf(mediaReferences.first().asJsonObject["mMediaType"].asString) + runCatching { - val downloadedMediaList = MediaDownloaderHelper.downloadMediaFromReference(mediaReferenceUrls[index]) { inputStream -> - media.attachmentInfo?.encryption?.let { EncryptionHelper.decryptInputStream(inputStream, it) } ?: inputStream + val downloadedMedia = RemoteMediaResolver.downloadBoltMedia(mediaReferenceKeys.first(), decryptionCallback = { + media.attachmentInfo?.encryption?.decryptInputStream(it) ?: it + }) ?: throw Throwable("Unable to download media") + + val downloadedMedias = mutableMapOf<SplitMediaAssetType, ByteArray>() + + MediaDownloaderHelper.getSplitElements(downloadedMedia.inputStream()) { type, inputStream -> + downloadedMedias[type] = inputStream.readBytes() } - var bitmapPreview = PreviewUtils.createPreview(downloadedMediaList[SplitMediaAssetType.ORIGINAL]!!, mediaType.name.contains("VIDEO"))!! + var bitmapPreview = PreviewUtils.createPreview(downloadedMedias[SplitMediaAssetType.ORIGINAL]!!, mediaType.name.contains("VIDEO"))!! - downloadedMediaList[SplitMediaAssetType.OVERLAY]?.let { + downloadedMedias[SplitMediaAssetType.OVERLAY]?.let { bitmapPreview = PreviewUtils.mergeBitmapOverlay(bitmapPreview, BitmapFactory.decodeByteArray(it, 0, it.size)) }