commit ec0fd2cf08ed312188d2d73f5fc0bcfadbb7f38f
parent 8e36425a369a0fd6d621c5f71d6b2e1fddf59e2b
Author: rhunk <101876869+rhunk@users.noreply.github.com>
Date:   Wed, 17 Jan 2024 16:18:20 +0100

perf(e2ee): cache optimization
- shared secret key cache
- remove uncommitted messages from the cache

Diffstat:
Mapp/src/main/kotlin/me/rhunk/snapenhance/e2ee/E2EEImplementation.kt | 41++++++++++++++++++-----------------------
Mcore/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/experiments/EndToEndEncryption.kt | 17+++++++++++++----
2 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/app/src/main/kotlin/me/rhunk/snapenhance/e2ee/E2EEImplementation.kt b/app/src/main/kotlin/me/rhunk/snapenhance/e2ee/E2EEImplementation.kt @@ -3,6 +3,7 @@ package me.rhunk.snapenhance.e2ee import me.rhunk.snapenhance.RemoteSideContext import me.rhunk.snapenhance.bridge.e2ee.E2eeInterface import me.rhunk.snapenhance.bridge.e2ee.EncryptionResult +import me.rhunk.snapenhance.core.util.EvictingMap import org.bouncycastle.pqc.crypto.crystals.kyber.* import java.io.File import java.security.MessageDigest @@ -23,25 +24,33 @@ class E2EEImplementation ( }} private val pairingFolder by lazy { File(context.androidContext.cacheDir, "e2ee-pairing").also { if (!it.exists()) it.mkdirs() + else { + it.deleteRecursively() + it.mkdirs() + } } } + private val sharedSecretKeyCache = EvictingMap<String, ByteArray?>(100) + fun storeSharedSecretKey(friendId: String, key: ByteArray) { File(e2eeFolder, "$friendId.key").writeBytes(key) + sharedSecretKeyCache[friendId] = key } fun getSharedSecretKey(friendId: String): ByteArray? { - return runCatching { - File(e2eeFolder, "$friendId.key").readBytes() - }.onFailure { - context.log.error("Failed to read shared secret key", it) - }.getOrNull() + return sharedSecretKeyCache.getOrPut(friendId) { + runCatching { + File(e2eeFolder, "$friendId.key").readBytes() + }.onFailure { + context.log.error("Failed to read shared secret key", it) + }.getOrNull() + } } fun deleteSharedSecretKey(friendId: String) { File(e2eeFolder, "$friendId.key").delete() } - override fun createKeyExchange(friendId: String): ByteArray? { val keyPairGenerator = KyberKeyPairGenerator() keyPairGenerator.init( @@ -117,12 +126,7 @@ class E2EEImplementation ( } override fun getSecretFingerprint(friendId: String): String? { - val sharedSecretKey = runCatching { - File(e2eeFolder, "$friendId.key").readBytes() - }.onFailure { - context.log.error("Failed to read shared secret key", it) - return null - }.getOrThrow() + val sharedSecretKey = getSharedSecretKey(friendId) ?: return null return MessageDigest.getInstance("SHA-256") .digest(sharedSecretKey) @@ -132,11 +136,7 @@ class E2EEImplementation ( } override fun encryptMessage(friendId: String, message: ByteArray): EncryptionResult? { - val encryptionKey = runCatching { - File(e2eeFolder, "$friendId.key").readBytes() - }.onFailure { - context.log.error("Failed to read shared secret key", it) - }.getOrNull() + val encryptionKey = getSharedSecretKey(friendId) ?: return null return runCatching { val iv = ByteArray(16).apply { secureRandom.nextBytes(this) } @@ -152,12 +152,7 @@ class E2EEImplementation ( } override fun decryptMessage(friendId: String, message: ByteArray, iv: ByteArray): ByteArray? { - val encryptionKey = runCatching { - File(e2eeFolder, "$friendId.key").readBytes() - }.onFailure { - context.log.error("Failed to read shared secret key", it) - return null - }.getOrNull() + val encryptionKey = getSharedSecretKey(friendId) ?: return null return runCatching { val cipher = Cipher.getInstance("AES/CBC/PKCS5Padding") diff --git a/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/experiments/EndToEndEncryption.kt b/core/src/main/kotlin/me/rhunk/snapenhance/core/features/impl/experiments/EndToEndEncryption.kt @@ -16,6 +16,7 @@ import androidx.compose.runtime.remember import androidx.compose.ui.Modifier import androidx.compose.ui.unit.dp import me.rhunk.snapenhance.common.data.ContentType +import me.rhunk.snapenhance.common.data.MessageState import me.rhunk.snapenhance.common.data.MessagingRuleType import me.rhunk.snapenhance.common.data.RuleState import me.rhunk.snapenhance.common.util.protobuf.ProtoEditor @@ -312,7 +313,7 @@ class EndToEndEncryption : MessagingRuleFeature( if (messageTypeId == ENCRYPTED_MESSAGE_ID) { runCatching { eachBuffer(2) { - if (encryptedMessages.contains(clientMessageId)) return@eachBuffer + if (decryptedMessageCache.containsKey(clientMessageId)) return@eachBuffer val participantIdHash = getByteArray(1) ?: return@eachBuffer val iv = getByteArray(2) ?: return@eachBuffer @@ -373,10 +374,15 @@ class EndToEndEncryption : MessagingRuleFeature( return outputContentType to outputBuffer } - private fun messageHook(conversationId: String, messageId: Long, senderId: String, messageContent: MessageContent) { + private fun messageHook(conversationId: String, messageId: Long, senderId: String, messageContent: MessageContent, committed: Boolean) { val (contentType, buffer) = tryDecryptMessage(senderId, messageId, conversationId, messageContent.contentType ?: ContentType.CHAT, messageContent.content!!) messageContent.contentType = contentType messageContent.content = buffer + // remove messages currently being sent from the cache + if (!committed) { + decryptedMessageCache.remove(messageId) + encryptedMessages.remove(messageId) + } } override fun asyncInit() { @@ -520,11 +526,13 @@ class EndToEndEncryption : MessagingRuleFeature( context.event.subscribe(BuildMessageEvent::class, priority = 0) { event -> val message = event.message val conversationId = message.messageDescriptor!!.conversationId.toString() + val isMessageCommitted = message.messageState == MessageState.COMMITTED messageHook( conversationId = conversationId, messageId = message.messageDescriptor!!.messageId!!, senderId = message.senderId.toString(), - messageContent = message.messageContent!! + messageContent = message.messageContent!!, + committed = isMessageCommitted ) message.messageContent!!.instanceNonNull() @@ -535,7 +543,8 @@ class EndToEndEncryption : MessagingRuleFeature( conversationId = conversationId, messageId = quotedMessage.getObjectField("mMessageId")?.toString()?.toLong() ?: return@also, senderId = SnapUUID(quotedMessage.getObjectField("mSenderId")).toString(), - messageContent = MessageContent(quotedMessage) + messageContent = MessageContent(quotedMessage), + committed = isMessageCommitted ) } }