commit f375f9bc0e1f6ca77e27473464e4b81af1d094ff
parent 4046d1a50658add099fb33396ff1a93fed012ac5
Author: rhunk <101876869+rhunk@users.noreply.github.com>
Date: Wed, 29 Nov 2023 23:56:55 +0100
fix(downloader): media identifier & dash chapter selector
- fix ffmpeg crashes
- fix close resources
- perf http server
Diffstat:
8 files changed, 144 insertions(+), 108 deletions(-)
diff --git a/app/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt b/app/src/main/kotlin/me/rhunk/snapenhance/download/DownloadProcessor.kt
@@ -122,10 +122,9 @@ class DownloadProcessor (
}
pendingTask.updateProgress("Converting image to $format")
- val outputStream = inputFile.outputStream()
- bitmap.compress(compressFormat, 100, outputStream)
- outputStream.close()
-
+ inputFile.outputStream().use {
+ bitmap.compress(compressFormat, 100, it)
+ }
fileType = FileType.fromFile(inputFile)
}
}
@@ -146,11 +145,12 @@ class DownloadProcessor (
}
val outputFile = outputFileFolder.createFile(fileType.mimeType, fileName)!!
- val outputStream = remoteSideContext.androidContext.contentResolver.openOutputStream(outputFile.uri)!!
pendingTask.updateProgress("Saving media to gallery")
- inputFile.inputStream().use { inputStream ->
- inputStream.copyTo(outputStream)
+ remoteSideContext.androidContext.contentResolver.openOutputStream(outputFile.uri)!!.use { outputStream ->
+ inputFile.inputStream().use { inputStream ->
+ inputStream.copyTo(outputStream)
+ }
}
pendingTask.task.extra = outputFile.uri.toString()
@@ -201,19 +201,20 @@ class DownloadProcessor (
fun handleInputStream(inputStream: InputStream, estimatedSize: Long = 0L) {
createMediaTempFile().apply {
val decryptedInputStream = (inputMedia.encryption?.decryptInputStream(inputStream) ?: inputStream).buffered()
- val outputStream = outputStream()
val buffer = ByteArray(DEFAULT_BUFFER_SIZE)
var read: Int
var totalRead = 0L
var lastTotalRead = 0L
- while (decryptedInputStream.read(buffer).also { read = it } != -1) {
- outputStream.write(buffer, 0, read)
- totalRead += read
- inputMediaDownloadedBytes[inputMedia] = totalRead
- if (totalRead - lastTotalRead > 1024 * 1024) {
- setProgress("${totalRead / 1024}KB/${estimatedSize / 1024}KB")
- lastTotalRead = totalRead
+ outputStream().use { outputStream ->
+ while (decryptedInputStream.read(buffer).also { read = it } != -1) {
+ outputStream.write(buffer, 0, read)
+ totalRead += read
+ inputMediaDownloadedBytes[inputMedia] = totalRead
+ if (totalRead - lastTotalRead > 1024 * 1024) {
+ setProgress("${totalRead / 1024}KB/${estimatedSize / 1024}KB")
+ lastTotalRead = totalRead
+ }
}
}
}.also { downloadedMedias[inputMedia] = it }
@@ -224,7 +225,9 @@ class DownloadProcessor (
DownloadMediaType.PROTO_MEDIA -> {
RemoteMediaResolver.downloadBoltMedia(Base64.UrlSafe.decode(inputMedia.content), decryptionCallback = { it }, resultCallback = { inputStream, length ->
totalSize += length
- handleInputStream(inputStream, estimatedSize = length)
+ inputStream.use {
+ handleInputStream(it, estimatedSize = length)
+ }
})
}
DownloadMediaType.REMOTE_MEDIA -> {
@@ -233,7 +236,9 @@ class DownloadProcessor (
setRequestProperty("User-Agent", Constants.USER_AGENT)
connect()
totalSize += contentLength.toLong()
- handleInputStream(inputStream, estimatedSize = contentLength.toLong())
+ inputStream.use {
+ handleInputStream(it, estimatedSize = contentLength.toLong())
+ }
}
}
DownloadMediaType.DIRECT_MEDIA -> {
@@ -292,8 +297,9 @@ class DownloadProcessor (
val dashOptions = downloadRequest.dashOptions!!
val dashPlaylistFile = renameFromFileType(media.file, FileType.MPD)
- val xmlData = dashPlaylistFile.outputStream()
- TransformerFactory.newInstance().newTransformer().transform(DOMSource(playlistXml), StreamResult(xmlData))
+ dashPlaylistFile.outputStream().use {
+ TransformerFactory.newInstance().newTransformer().transform(DOMSource(playlistXml), StreamResult(it))
+ }
callbackOnProgress(translation.format("download_toast", "path" to dashPlaylistFile.nameWithoutExtension))
val outputFile = File.createTempFile("dash", ".mp4")
@@ -329,9 +335,8 @@ class DownloadProcessor (
remoteSideContext.coroutineScope.launch {
val downloadMetadata = gson.fromJson(intent.getStringExtra(ReceiversConfig.DOWNLOAD_METADATA_EXTRA)!!, DownloadMetadata::class.java)
val downloadRequest = gson.fromJson(intent.getStringExtra(ReceiversConfig.DOWNLOAD_REQUEST_EXTRA)!!, DownloadRequest::class.java)
- val downloadId = (downloadMetadata.mediaIdentifier ?: UUID.randomUUID().toString()).longHashCode().absoluteValue.toString(16)
- remoteSideContext.taskManager.getTaskByHash(downloadId)?.let { task ->
+ remoteSideContext.taskManager.getTaskByHash(downloadMetadata.mediaIdentifier)?.let { task ->
remoteSideContext.log.debug("already queued or downloaded")
if (task.status.isFinalStage()) {
@@ -348,7 +353,7 @@ class DownloadProcessor (
Task(
type = TaskType.DOWNLOAD,
title = downloadMetadata.downloadSource + " (" + downloadMetadata.mediaAuthor + ")",
- hash = downloadId
+ hash = downloadMetadata.mediaIdentifier
)
).apply {
status = TaskStatus.RUNNING
@@ -372,15 +377,19 @@ class DownloadProcessor (
val oldDownloadedMedias = downloadedMedias.toMap()
downloadedMedias.clear()
- MediaDownloaderHelper.getSplitElements(zipFile.file.inputStream()) { type, inputStream ->
- createMediaTempFile().apply {
- inputStream.copyTo(outputStream())
- }.also {
- downloadedMedias[InputMedia(
- type = DownloadMediaType.LOCAL_MEDIA,
- content = it.absolutePath,
- isOverlay = type == SplitMediaAssetType.OVERLAY
- )] = DownloadedFile(it, FileType.fromFile(it))
+ zipFile.file.inputStream().use { zipFileInputStream ->
+ MediaDownloaderHelper.getSplitElements(zipFileInputStream) { type, inputStream ->
+ createMediaTempFile().apply {
+ outputStream().use {
+ inputStream.copyTo(it)
+ }
+ }.also {
+ downloadedMedias[InputMedia(
+ type = DownloadMediaType.LOCAL_MEDIA,
+ content = it.absolutePath,
+ isOverlay = type == SplitMediaAssetType.OVERLAY
+ )] = DownloadedFile(it, FileType.fromFile(it))
+ }
}
}
diff --git a/app/src/main/kotlin/me/rhunk/snapenhance/download/FFMpegProcessor.kt b/app/src/main/kotlin/me/rhunk/snapenhance/download/FFMpegProcessor.kt
@@ -94,6 +94,8 @@ class FFMpegProcessor(
}
suspend fun execute(args: Request) {
+ // load ffmpeg native sync to avoid native crash
+ synchronized(this) { FFmpegKit.listSessions() }
val globalArguments = ArgumentList().apply {
this += "-y"
this += "-threads" to ffmpegOptions.threads.get().toString()
diff --git a/common/src/main/kotlin/me/rhunk/snapenhance/common/config/impl/DownloaderConfig.kt b/common/src/main/kotlin/me/rhunk/snapenhance/common/config/impl/DownloaderConfig.kt
@@ -46,6 +46,6 @@ class DownloaderConfig : ConfigContainer() {
val chatDownloadContextMenu = boolean("chat_download_context_menu")
val ffmpegOptions = container("ffmpeg_options", FFMpegOptions()) { addNotices(FeatureNotice.UNSTABLE) }
val logging = multiple("logging", "started", "success", "progress", "failure").apply {
- set(mutableListOf("started", "success"))
+ set(mutableListOf("success", "progress", "failure"))
}
}
\ No newline at end of file
diff --git a/common/src/main/kotlin/me/rhunk/snapenhance/common/data/download/DownloadMetadata.kt b/common/src/main/kotlin/me/rhunk/snapenhance/common/data/download/DownloadMetadata.kt
@@ -1,7 +1,7 @@
package me.rhunk.snapenhance.common.data.download
data class DownloadMetadata(
- val mediaIdentifier: String?,
+ val mediaIdentifier: String,
val outputPath: String,
val mediaAuthor: String?,
val downloadSource: String,
diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/DownloadManagerClient.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/DownloadManagerClient.kt
@@ -26,9 +26,9 @@ class DownloadManagerClient (
DownloadRequest(
inputMedias = arrayOf(
InputMedia(
- content = playlistUrl,
- type = DownloadMediaType.REMOTE_MEDIA
- )
+ content = playlistUrl,
+ type = DownloadMediaType.REMOTE_MEDIA
+ )
),
dashOptions = DashOptions(offsetTime, duration),
flags = DownloadRequest.Flags.IS_DASH_PLAYLIST
diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/Stories.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/Stories.kt
@@ -17,8 +17,11 @@ class Stories : Feature("Stories", loadParams = FeatureLoadParams.ACTIVITY_CREAT
fun cancelRequest() {
runBlocking {
suspendCoroutine {
- context.httpServer.ensureServerStarted {
- event.url = "http://127.0.0.1:${context.httpServer.port}"
+ context.httpServer.ensureServerStarted()?.let { server ->
+ event.url = "http://127.0.0.1:${server.port}"
+ it.resumeWith(Result.success(Unit))
+ } ?: run {
+ event.canceled = true
it.resumeWith(Result.success(Unit))
}
}
diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/downloader/MediaDownloader.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/downloader/MediaDownloader.kt
@@ -25,6 +25,7 @@ import me.rhunk.snapenhance.common.data.download.MediaDownloadSource
import me.rhunk.snapenhance.common.data.download.SplitMediaAssetType
import me.rhunk.snapenhance.common.database.impl.ConversationMessage
import me.rhunk.snapenhance.common.database.impl.FriendInfo
+import me.rhunk.snapenhance.common.util.ktx.longHashCode
import me.rhunk.snapenhance.common.util.protobuf.ProtoReader
import me.rhunk.snapenhance.common.util.snap.BitmojiSelfie
import me.rhunk.snapenhance.common.util.snap.MediaDownloaderHelper
@@ -54,9 +55,11 @@ import java.io.ByteArrayInputStream
import java.nio.file.Paths
import java.text.SimpleDateFormat
import java.util.Locale
+import java.util.UUID
import kotlin.coroutines.suspendCoroutine
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi
+import kotlin.math.absoluteValue
private fun String.sanitizeForPath(): String {
return this.replace(" ", "_")
@@ -85,7 +88,11 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
downloadSource: MediaDownloadSource,
friendInfo: FriendInfo? = null
): DownloadManagerClient {
- val generatedHash = mediaIdentifier.hashCode().toString(16).replaceFirst("-", "")
+ val generatedHash = (
+ if (!context.config.downloader.allowDuplicate.get()) mediaIdentifier
+ else UUID.randomUUID().toString()
+ ).longHashCode().absoluteValue.toString(16)
+
val iconUrl = BitmojiSelfie.getBitmojiSelfie(friendInfo?.bitmojiSelfieId, friendInfo?.bitmojiAvatarId, BitmojiSelfie.BitmojiSelfieType.THREE_D)
val downloadLogging by context.config.downloader.logging
@@ -98,9 +105,7 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
return DownloadManagerClient(
context = context,
metadata = DownloadMetadata(
- mediaIdentifier = if (!context.config.downloader.allowDuplicate.get()) {
- generatedHash
- } else null,
+ mediaIdentifier = generatedHash,
mediaAuthor = mediaAuthor,
downloadSource = downloadSource.key,
iconUrl = iconUrl,
@@ -161,7 +166,7 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
finalPath.append(downloadSource.pathName).append("/")
}
if (pathFormat.contains("append_hash")) {
- appendFileName(hexHash)
+ appendFileName(hexHash.substring(0, hexHash.length.coerceAtMost(8)))
}
if (pathFormat.contains("append_source")) {
appendFileName(downloadSource.pathName)
@@ -228,10 +233,12 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
Uri.parse(path).let { uri ->
if (uri.scheme == "file") {
return@let suspendCoroutine<String> { continuation ->
- context.httpServer.ensureServerStarted {
+ context.httpServer.ensureServerStarted()?.let { server ->
val file = Paths.get(uri.path).toFile()
- val url = putDownloadableContent(file.inputStream(), file.length())
+ val url = server.putDownloadableContent(file.inputStream(), file.length())
continuation.resumeWith(Result.success(url))
+ } ?: run {
+ continuation.resumeWith(Result.failure(Exception("Failed to start http server")))
}
}
}
@@ -426,7 +433,12 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
setTitle("Download dash media")
setMultiChoiceItems(
chapters.map { "Segment ${prettyPrintTime(it.offset)} - ${prettyPrintTime(it.offset + (it.duration ?: 0))}" }.toTypedArray(),
- List(chapters.size) { index -> currentChapterIndex == index }.toBooleanArray()
+ List(chapters.size) { index ->
+ if (currentChapterIndex == index) {
+ selectedChapters.add(index)
+ true
+ } else false
+ }.toBooleanArray()
) { _, which, isChecked ->
if (isChecked) {
selectedChapters.add(which)
@@ -444,22 +456,19 @@ class MediaDownloader : MessagingRuleFeature("MediaDownloader", MessagingRuleTyp
}
setPositiveButton("Download") { _, _ ->
val groups = mutableListOf<MutableList<SnapChapterInfo>>()
- var currentGroup = mutableListOf<SnapChapterInfo>()
- var lastChapterIndex = -1
- //check for consecutive chapters
- chapters.filterIndexed { index, _ -> selectedChapters.contains(index) }
- .forEachIndexed { index, pair ->
- if (lastChapterIndex != -1 && index != lastChapterIndex + 1) {
- groups.add(currentGroup)
- currentGroup = mutableListOf()
+ var lastChapterIndex = -1
+ // group consecutive chapters
+ chapters.forEachIndexed { index, snapChapter ->
+ lastChapterIndex = if (selectedChapters.contains(index)) {
+ if (lastChapterIndex == -1) {
+ groups.add(mutableListOf())
}
- currentGroup.add(pair)
- lastChapterIndex = index
- }
-
- if (currentGroup.isNotEmpty()) {
- groups.add(currentGroup)
+ groups.last().add(snapChapter)
+ index
+ } else {
+ -1
+ }
}
groups.forEach { group ->
diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/util/media/HttpServer.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/util/media/HttpServer.kt
@@ -1,10 +1,6 @@
package me.rhunk.snapenhance.core.util.media
-import kotlinx.coroutines.CoroutineScope
-import kotlinx.coroutines.Dispatchers
-import kotlinx.coroutines.Job
-import kotlinx.coroutines.delay
-import kotlinx.coroutines.launch
+import kotlinx.coroutines.*
import me.rhunk.snapenhance.common.logger.AbstractLogger
import java.io.BufferedReader
import java.io.InputStream
@@ -16,12 +12,16 @@ import java.net.SocketException
import java.util.Locale
import java.util.StringTokenizer
import java.util.concurrent.ConcurrentHashMap
+import kotlin.coroutines.suspendCoroutine
import kotlin.random.Random
class HttpServer(
private val timeout: Int = 10000
) {
- val port = Random.nextInt(10000, 65535)
+ private fun newRandomPort() = Random.nextInt(10000, 65535)
+
+ var port = newRandomPort()
+ private set
private val coroutineScope = CoroutineScope(Dispatchers.IO)
private var timeoutJob: Job? = null
@@ -30,42 +30,56 @@ class HttpServer(
private val cachedData = ConcurrentHashMap<String, Pair<InputStream, Long>>()
private var serverSocket: ServerSocket? = null
- fun ensureServerStarted(callback: HttpServer.() -> Unit) {
- if (serverSocket != null && !serverSocket!!.isClosed) {
- callback(this)
- return
- }
+ fun ensureServerStarted(): HttpServer? {
+ if (serverSocket != null && serverSocket?.isClosed != true) return this
+
+ return runBlocking {
+ withTimeoutOrNull(5000L) {
+ suspendCoroutine { continuation ->
+ coroutineScope.launch(Dispatchers.IO) {
+ AbstractLogger.directDebug("Starting http server on port $port")
+ for (i in 0..5) {
+ try {
+ serverSocket = ServerSocket(port)
+ break
+ } catch (e: Throwable) {
+ AbstractLogger.directError("failed to start http server on port $port", e)
+ port = newRandomPort()
+ }
+ }
+ continuation.resumeWith(Result.success(if (serverSocket == null) null.also {
+ return@launch
+ } else this@HttpServer))
- coroutineScope.launch(Dispatchers.IO) {
- AbstractLogger.directDebug("starting http server on port $port")
- serverSocket = ServerSocket(port)
- callback(this@HttpServer)
- while (!serverSocket!!.isClosed) {
- try {
- val socket = serverSocket!!.accept()
- timeoutJob?.cancel()
- launch {
- handleRequest(socket)
- timeoutJob = launch {
- delay(timeout.toLong())
- AbstractLogger.directDebug("http server closed due to timeout")
- runCatching {
- socketJob?.cancel()
- socket.close()
- serverSocket?.close()
- }.onFailure {
- AbstractLogger.directError("failed to close socket", it)
+ while (!serverSocket!!.isClosed) {
+ try {
+ val socket = serverSocket!!.accept()
+ timeoutJob?.cancel()
+ launch {
+ handleRequest(socket)
+ timeoutJob = launch {
+ delay(timeout.toLong())
+ AbstractLogger.directDebug("http server closed due to timeout")
+ runCatching {
+ socketJob?.cancel()
+ socket.close()
+ serverSocket?.close()
+ }.onFailure {
+ AbstractLogger.directError("failed to close socket", it)
+ }
+ }
+ }
+ } catch (e: SocketException) {
+ AbstractLogger.directDebug("http server timed out")
+ break;
+ } catch (e: Throwable) {
+ AbstractLogger.directError("failed to handle request", e)
}
}
- }
- } catch (e: SocketException) {
- AbstractLogger.directDebug("http server timed out")
- break;
- } catch (e: Throwable) {
- AbstractLogger.directError("failed to handle request", e)
+ }.also { socketJob = it }
}
}
- }.also { socketJob = it }
+ }
}
fun close() {
@@ -112,18 +126,15 @@ class HttpServer(
if (fileRequested.startsWith("/")) {
fileRequested = fileRequested.substring(1)
}
- if (!cachedData.containsKey(fileRequested)) {
- with(writer) {
- println("HTTP/1.1 404 Not Found")
- println("Content-type: " + "application/octet-stream")
- println("Content-length: " + 0)
- println()
- flush()
- }
+ val requestedData = cachedData[fileRequested] ?: writer.run {
+ println("HTTP/1.1 404 Not Found")
+ println("Content-type: " + "application/octet-stream")
+ println("Content-length: " + 0)
+ println()
+ flush()
close()
return
}
- val requestedData = cachedData[fileRequested]!!
with(writer) {
println("HTTP/1.1 200 OK")
println("Content-type: " + "application/octet-stream")
@@ -131,9 +142,11 @@ class HttpServer(
println()
flush()
}
- requestedData.first.copyTo(outputStream)
- outputStream.flush()
cachedData.remove(fileRequested)
+ requestedData.first.use {
+ it.copyTo(outputStream)
+ }
+ outputStream.flush()
close()
}
}
\ No newline at end of file