commit f375f9bc0e1f6ca77e27473464e4b81af1d094ff
parent 4046d1a50658add099fb33396ff1a93fed012ac5
Author: rhunk <101876869+rhunk@users.noreply.github.com>
Date:   Wed, 29 Nov 2023 23:56:55 +0100

fix(downloader): media identifier & dash chapter selector
- fix ffmpeg crashes
- fix close resources
- perf http server

Diffstat:
Mapp/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt | 71++++++++++++++++++++++++++++++++++++++++-------------------------------
Mapp/src/main/kotlin/me/rhunk/snapenhance/download/FFMpegProcessor.kt | 2++
Mcommon/src/main/kotlin/me/rhunk/snapenhance/common/config/impl/DownloaderConfig.kt | 2+-
Mcommon/src/main/kotlin/me/rhunk/snapenhance/common/data/download/DownloadMetadata.kt | 2+-
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/DownloadManagerClient.kt | 6+++---
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/Stories.kt | 7+++++--
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/downloader/MediaDownloader.kt | 53+++++++++++++++++++++++++++++++----------------------
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/util/media/HttpServer.kt | 109++++++++++++++++++++++++++++++++++++++++++++-----------------------------------
8 files changed, 144 insertions(+), 108 deletions(-)

diff --git a/app/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt b/app/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt @@ -122,10 +122,9 @@ class DownloadProcessor ( } pendingTask.updateProgress("Converting image to $format") - val outputStream = inputFile.outputStream() - bitmap.compress(compressFormat, 100, outputStream) - outputStream.close() - + inputFile.outputStream().use { + bitmap.compress(compressFormat, 100, it) + } fileType = FileType.fromFile(inputFile) } } @@ -146,11 +145,12 @@ class DownloadProcessor ( } val outputFile = outputFileFolder.createFile(fileType.mimeType, fileName)!! - val outputStream = remoteSideContext.androidContext.contentResolver.openOutputStream(outputFile.uri)!! pendingTask.updateProgress("Saving media to gallery") - inputFile.inputStream().use { inputStream -> - inputStream.copyTo(outputStream) + remoteSideContext.androidContext.contentResolver.openOutputStream(outputFile.uri)!!.use { outputStream -> + inputFile.inputStream().use { inputStream -> + inputStream.copyTo(outputStream) + } } pendingTask.task.extra = outputFile.uri.toString() @@ -201,19 +201,20 @@ class DownloadProcessor ( fun handleInputStream(inputStream: InputStream, estimatedSize: Long = 0L) { createMediaTempFile().apply { val decryptedInputStream = (inputMedia.encryption?.decryptInputStream(inputStream) ?: inputStream).buffered() - val outputStream = outputStream() val buffer = ByteArray(DEFAULT_BUFFER_SIZE) var read: Int var totalRead = 0L var lastTotalRead = 0L - while (decryptedInputStream.read(buffer).also { read = it } != -1) { - outputStream.write(buffer, 0, read) - totalRead += read - inputMediaDownloadedBytes[inputMedia] = totalRead - if (totalRead - lastTotalRead > 1024 * 1024) { - setProgress("${totalRead / 1024}KB/${estimatedSize / 1024}KB") - lastTotalRead = totalRead + outputStream().use { outputStream -> + while (decryptedInputStream.read(buffer).also { read = it } != -1) { + outputStream.write(buffer, 0, read) + totalRead += read + inputMediaDownloadedBytes[inputMedia] = totalRead + if (totalRead - lastTotalRead > 1024 * 1024) { + setProgress("${totalRead / 1024}KB/${estimatedSize / 1024}KB") + lastTotalRead = totalRead + } } } }.also { downloadedMedias[inputMedia] = it } @@ -224,7 +225,9 @@ class DownloadProcessor ( DownloadMediaType.PROTO_MEDIA -> { RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(inputMedia.content), decryptionCallback = { it }, resultCallback = { inputStream, length -> totalSize += length - handleInputStream(inputStream, estimatedSize = length) + inputStream.use { + handleInputStream(it, estimatedSize = length) + } }) } DownloadMediaType.REMOTE_MEDIA -> { @@ -233,7 +236,9 @@ class DownloadProcessor ( setRequestProperty("User-Agent", Constants.USER_AGENT) connect() totalSize += contentLength.toLong() - handleInputStream(inputStream, estimatedSize = contentLength.toLong()) + inputStream.use { + handleInputStream(it, estimatedSize = contentLength.toLong()) + } } } DownloadMediaType.DIRECT_MEDIA -> { @@ -292,8 +297,9 @@ class DownloadProcessor ( val dashOptions = downloadRequest.dashOptions!! val dashPlaylistFile = renameFromFileType(media.file, FileType.MPD) - val xmlData = dashPlaylistFile.outputStream() - TransformerFactory.newInstance().newTransformer().transform(DOMSource(playlistXml), StreamResult(xmlData)) + dashPlaylistFile.outputStream().use { + TransformerFactory.newInstance().newTransformer().transform(DOMSource(playlistXml), StreamResult(it)) + } callbackOnProgress(translation.format("download_toast", "path" to dashPlaylistFile.nameWithoutExtension)) val outputFile = File.createTempFile("dash", ".mp4") @@ -329,9 +335,8 @@ class DownloadProcessor ( remoteSideContext.coroutineScope.launch { val downloadMetadata = gson.fromJson(intent.getStringExtra(ReceiversConfig.DOWNLOAD_METADATA_EXTRA)!!, DownloadMetadata::class.java) val downloadRequest = gson.fromJson(intent.getStringExtra(ReceiversConfig.DOWNLOAD_REQUEST_EXTRA)!!, DownloadRequest::class.java) - val downloadId = (downloadMetadata.mediaIdentifier ?: UUID.randomUUID().toString()).longHashCode().absoluteValue.toString(16) - remoteSideContext.taskManager.getTaskByHash(downloadId)?.let { task -> + remoteSideContext.taskManager.getTaskByHash(downloadMetadata.mediaIdentifier)?.let { task -> remoteSideContext.log.debug("already queued or downloaded") if (task.status.isFinalStage()) { @@ -348,7 +353,7 @@ class DownloadProcessor ( Task( type = TaskType.DOWNLOAD, title = downloadMetadata.downloadSource + " (" + downloadMetadata.mediaAuthor + ")", - hash = downloadId + hash = downloadMetadata.mediaIdentifier ) ).apply { status = TaskStatus.RUNNING @@ -372,15 +377,19 @@ class DownloadProcessor ( val oldDownloadedMedias = downloadedMedias.toMap() downloadedMedias.clear() - MediaDownloaderHelper.getSplitElements(zipFile.file.inputStream()) { type, inputStream -> - createMediaTempFile().apply { - inputStream.copyTo(outputStream()) - }.also { - downloadedMedias[InputMedia( - type = DownloadMediaType.LOCAL_MEDIA, - content = it.absolutePath, - isOverlay = type == SplitMediaAssetType.OVERLAY - )] = DownloadedFile(it, FileType.fromFile(it)) + zipFile.file.inputStream().use { zipFileInputStream -> + MediaDownloaderHelper.getSplitElements(zipFileInputStream) { type, inputStream -> + createMediaTempFile().apply { + outputStream().use { + inputStream.copyTo(it) + } + }.also { + downloadedMedias[InputMedia( + type = DownloadMediaType.LOCAL_MEDIA, + content = it.absolutePath, + isOverlay = type == SplitMediaAssetType.OVERLAY + )] = DownloadedFile(it, FileType.fromFile(it)) + } } } diff --git a/app/src/main/kotlin/me/rhunk/snapenhance/download/FFMpegProcessor.kt b/app/src/main/kotlin/me/rhunk/snapenhance/download/FFMpegProcessor.kt @@ -94,6 +94,8 @@ class FFMpegProcessor( } suspend fun execute(args: Request) { + // load ffmpeg native sync to avoid native crash + synchronized(this) { FFmpegKit.listSessions() } val globalArguments = ArgumentList().apply { this += "-y" this += "-threads" to ffmpegOptions.threads.get().toString() diff --git a/common/src/main/kotlin/me/rhunk/snapenhance/common/config/impl/DownloaderConfig.kt b/common/src/main/kotlin/me/rhunk/snapenhance/common/config/impl/DownloaderConfig.kt @@ -46,6 +46,6 @@ class DownloaderConfig : ConfigContainer() { val chatDownloadContextMenu = boolean("chat_download_context_menu") val ffmpegOptions = container("ffmpeg_options", FFMpegOptions()) { addNotices(FeatureNotice.UNSTABLE) } val logging = multiple("logging", "started", "success", "progress", "failure").apply { - set(mutableListOf("started", "success")) + set(mutableListOf("success", "progress", "failure")) } } \ No newline at end of file diff --git a/common/src/main/kotlin/me/rhunk/snapenhance/common/data/download/DownloadMetadata.kt b/common/src/main/kotlin/me/rhunk/snapenhance/common/data/download/DownloadMetadata.kt @@ -1,7 +1,7 @@ package me.rhunk.snapenhance.common.data.download data class DownloadMetadata( - val mediaIdentifier: String?, + val mediaIdentifier: String, val outputPath: String, val mediaAuthor: String?, val downloadSource: String, diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/DownloadManagerClient.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/DownloadManagerClient.kt @@ -26,9 +26,9 @@ class DownloadManagerClient ( DownloadRequest( inputMedias = arrayOf( InputMedia( - content = playlistUrl, - type = DownloadMediaType.REMOTE_MEDIA - ) + content = playlistUrl, + type = DownloadMediaType.REMOTE_MEDIA + ) ), dashOptions = DashOptions(offsetTime, duration), flags = DownloadRequest.Flags.IS_DASH_PLAYLIST diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/Stories.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/Stories.kt @@ -17,8 +17,11 @@ class Stories : Feature("Stories", loadParams = FeatureLoadParams.ACTIVITY_CREAT fun cancelRequest() { runBlocking { suspendCoroutine { - context.httpServer.ensureServerStarted { - event.url = "http://127.0.0.1:${context.httpServer.port}" + context.httpServer.ensureServerStarted()?.let { server -> + event.url = "http://127.0.0.1:${server.port}" + it.resumeWith(Result.success(Unit)) + } ?: run { + event.canceled = true it.resumeWith(Result.success(Unit)) } } diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/downloader/MediaDownloader.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/downloader/MediaDownloader.kt @@ -25,6 +25,7 @@ import me.rhunk.snapenhance.common.data.download.MediaDownloadSource import me.rhunk.snapenhance.common.data.download.SplitMediaAssetType import me.rhunk.snapenhance.common.database.impl.ConversationMessage import me.rhunk.snapenhance.common.database.impl.FriendInfo +import me.rhunk.snapenhance.common.util.ktx.longHashCode import me.rhunk.snapenhance.common.util.protobuf.ProtoReader import me.rhunk.snapenhance.common.util.snap.BitmojiSelfie import me.rhunk.snapenhance.common.util.snap.MediaDownloaderHelper @@ -54,9 +55,11 @@ import java.io.ByteArrayInputStream import java.nio.file.Paths import java.text.SimpleDateFormat import java.util.Locale +import java.util.UUID import kotlin.coroutines.suspendCoroutine import kotlin.io.encoding.Base64 import kotlin.io.encoding.ExperimentalEncodingApi +import kotlin.math.absoluteValue private fun String.sanitizeForPath(): String { return this.replace(" ", "_") @@ -85,7 +88,11 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp downloadSource: MediaDownloadSource, friendInfo: FriendInfo? = null ): DownloadManagerClient { - val generatedHash = mediaIdentifier.hashCode().toString(16).replaceFirst("-", "") + val generatedHash = ( + if (!context.config.downloader.allowDuplicate.get()) mediaIdentifier + else UUID.randomUUID().toString() + ).longHashCode().absoluteValue.toString(16) + val iconUrl = BitmojiSelfie.getBitmojiSelfie(friendInfo?.bitmojiSelfieId, friendInfo?.bitmojiAvatarId, BitmojiSelfie.BitmojiSelfieType.THREE_D) val downloadLogging by context.config.downloader.logging @@ -98,9 +105,7 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp return DownloadManagerClient( context = context, metadata = DownloadMetadata( - mediaIdentifier = if (!context.config.downloader.allowDuplicate.get()) { - generatedHash - } else null, + mediaIdentifier = generatedHash, mediaAuthor = mediaAuthor, downloadSource = downloadSource.key, iconUrl = iconUrl, @@ -161,7 +166,7 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp finalPath.append(downloadSource.pathName).append("/") } if (pathFormat.contains("append_hash")) { - appendFileName(hexHash) + appendFileName(hexHash.substring(0, hexHash.length.coerceAtMost(8))) } if (pathFormat.contains("append_source")) { appendFileName(downloadSource.pathName) @@ -228,10 +233,12 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp Uri.parse(path).let { uri -> if (uri.scheme == "file") { return@let suspendCoroutine<String> { continuation -> - context.httpServer.ensureServerStarted { + context.httpServer.ensureServerStarted()?.let { server -> val file = Paths.get(uri.path).toFile() - val url = putDownloadableContent(file.inputStream(), file.length()) + val url = server.putDownloadableContent(file.inputStream(), file.length()) continuation.resumeWith(Result.success(url)) + } ?: run { + continuation.resumeWith(Result.failure(Exception("Failed to start http server"))) } } } @@ -426,7 +433,12 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp setTitle("Download dash media") setMultiChoiceItems( chapters.map { "Segment ${prettyPrintTime(it.offset)} - ${prettyPrintTime(it.offset + (it.duration ?: 0))}" }.toTypedArray(), - List(chapters.size) { index -> currentChapterIndex == index }.toBooleanArray() + List(chapters.size) { index -> + if (currentChapterIndex == index) { + selectedChapters.add(index) + true + } else false + }.toBooleanArray() ) { _, which, isChecked -> if (isChecked) { selectedChapters.add(which) @@ -444,22 +456,19 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp } setPositiveButton("Download") { _, _ -> val groups = mutableListOf<MutableList<SnapChapterInfo>>() - var currentGroup = mutableListOf<SnapChapterInfo>() - var lastChapterIndex = -1 - //check for consecutive chapters - chapters.filterIndexed { index, _ -> selectedChapters.contains(index) } - .forEachIndexed { index, pair -> - if (lastChapterIndex != -1 && index != lastChapterIndex + 1) { - groups.add(currentGroup) - currentGroup = mutableListOf() + var lastChapterIndex = -1 + // group consecutive chapters + chapters.forEachIndexed { index, snapChapter -> + lastChapterIndex = if (selectedChapters.contains(index)) { + if (lastChapterIndex == -1) { + groups.add(mutableListOf()) } - currentGroup.add(pair) - lastChapterIndex = index - } - - if (currentGroup.isNotEmpty()) { - groups.add(currentGroup) + groups.last().add(snapChapter) + index + } else { + -1 + } } groups.forEach { group -> diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/util/media/HttpServer.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/util/media/HttpServer.kt @@ -1,10 +1,6 @@ package me.rhunk.snapenhance.core.util.media -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job -import kotlinx.coroutines.delay -import kotlinx.coroutines.launch +import kotlinx.coroutines.* import me.rhunk.snapenhance.common.logger.AbstractLogger import java.io.BufferedReader import java.io.InputStream @@ -16,12 +12,16 @@ import java.net.SocketException import java.util.Locale import java.util.StringTokenizer import java.util.concurrent.ConcurrentHashMap +import kotlin.coroutines.suspendCoroutine import kotlin.random.Random class HttpServer( private val timeout: Int = 10000 ) { - val port = Random.nextInt(10000, 65535) + private fun newRandomPort() = Random.nextInt(10000, 65535) + + var port = newRandomPort() + private set private val coroutineScope = CoroutineScope(Dispatchers.IO) private var timeoutJob: Job? = null @@ -30,42 +30,56 @@ class HttpServer( private val cachedData = ConcurrentHashMap<String, Pair<InputStream, Long>>() private var serverSocket: ServerSocket? = null - fun ensureServerStarted(callback: HttpServer.() -> Unit) { - if (serverSocket != null && !serverSocket!!.isClosed) { - callback(this) - return - } + fun ensureServerStarted(): HttpServer? { + if (serverSocket != null && serverSocket?.isClosed != true) return this + + return runBlocking { + withTimeoutOrNull(5000L) { + suspendCoroutine { continuation -> + coroutineScope.launch(Dispatchers.IO) { + AbstractLogger.directDebug("Starting http server on port $port") + for (i in 0..5) { + try { + serverSocket = ServerSocket(port) + break + } catch (e: Throwable) { + AbstractLogger.directError("failed to start http server on port $port", e) + port = newRandomPort() + } + } + continuation.resumeWith(Result.success(if (serverSocket == null) null.also { + return@launch + } else this@HttpServer)) - coroutineScope.launch(Dispatchers.IO) { - AbstractLogger.directDebug("starting http server on port $port") - serverSocket = ServerSocket(port) - callback(this@HttpServer) - while (!serverSocket!!.isClosed) { - try { - val socket = serverSocket!!.accept() - timeoutJob?.cancel() - launch { - handleRequest(socket) - timeoutJob = launch { - delay(timeout.toLong()) - AbstractLogger.directDebug("http server closed due to timeout") - runCatching { - socketJob?.cancel() - socket.close() - serverSocket?.close() - }.onFailure { - AbstractLogger.directError("failed to close socket", it) + while (!serverSocket!!.isClosed) { + try { + val socket = serverSocket!!.accept() + timeoutJob?.cancel() + launch { + handleRequest(socket) + timeoutJob = launch { + delay(timeout.toLong()) + AbstractLogger.directDebug("http server closed due to timeout") + runCatching { + socketJob?.cancel() + socket.close() + serverSocket?.close() + }.onFailure { + AbstractLogger.directError("failed to close socket", it) + } + } + } + } catch (e: SocketException) { + AbstractLogger.directDebug("http server timed out") + break; + } catch (e: Throwable) { + AbstractLogger.directError("failed to handle request", e) } } - } - } catch (e: SocketException) { - AbstractLogger.directDebug("http server timed out") - break; - } catch (e: Throwable) { - AbstractLogger.directError("failed to handle request", e) + }.also { socketJob = it } } } - }.also { socketJob = it } + } } fun close() { @@ -112,18 +126,15 @@ class HttpServer( if (fileRequested.startsWith("/")) { fileRequested = fileRequested.substring(1) } - if (!cachedData.containsKey(fileRequested)) { - with(writer) { - println("HTTP/1.1 404 Not Found") - println("Content-type: " + "application/octet-stream") - println("Content-length: " + 0) - println() - flush() - } + val requestedData = cachedData[fileRequested] ?: writer.run { + println("HTTP/1.1 404 Not Found") + println("Content-type: " + "application/octet-stream") + println("Content-length: " + 0) + println() + flush() close() return } - val requestedData = cachedData[fileRequested]!! with(writer) { println("HTTP/1.1 200 OK") println("Content-type: " + "application/octet-stream") @@ -131,9 +142,11 @@ class HttpServer( println() flush() } - requestedData.first.copyTo(outputStream) - outputStream.flush() cachedData.remove(fileRequested) + requestedData.first.use { + it.copyTo(outputStream) + } + outputStream.flush() close() } } \ No newline at end of file