JSModule.kt (13311B) - raw


      1 package me.rhunk.snapenhance.common.scripting
      2 
      3 import android.os.Handler
      4 import android.widget.Toast
      5 import kotlinx.coroutines.*
      6 import me.rhunk.snapenhance.common.scripting.bindings.AbstractBinding
      7 import me.rhunk.snapenhance.common.scripting.bindings.BindingsContext
      8 import me.rhunk.snapenhance.common.scripting.impl.JavaInterfaces
      9 import me.rhunk.snapenhance.common.scripting.impl.Networking
     10 import me.rhunk.snapenhance.common.scripting.impl.Protobuf
     11 import me.rhunk.snapenhance.common.scripting.ktx.contextScope
     12 import me.rhunk.snapenhance.common.scripting.ktx.putFunction
     13 import me.rhunk.snapenhance.common.scripting.ktx.scriptable
     14 import me.rhunk.snapenhance.common.scripting.ktx.scriptableObject
     15 import me.rhunk.snapenhance.common.scripting.type.ModuleInfo
     16 import me.rhunk.snapenhance.common.scripting.type.Permissions
     17 import me.rhunk.snapenhance.common.scripting.ui.InterfaceManager
     18 import org.mozilla.javascript.*
     19 import org.mozilla.javascript.Function
     20 import java.io.Reader
     21 import java.lang.reflect.Modifier
     22 import kotlin.reflect.KClass
     23 
     24 class JSModule(
     25     private val scriptRuntime: ScriptRuntime,
     26     val moduleInfo: ModuleInfo,
     27     private val reader: Reader,
     28 ) {
     29     val coroutineScope = CoroutineScope(Dispatchers.IO)
     30     private val moduleBindings = mutableMapOf<String, AbstractBinding>()
     31     private lateinit var moduleObject: ScriptableObject
     32 
     33     private val moduleBindingContext by lazy {
     34         BindingsContext(
     35             moduleInfo = moduleInfo,
     36             runtime = scriptRuntime,
     37             module = this,
     38         )
     39     }
     40 
     41     fun load(block: ScriptableObject.() -> Unit) {
     42         contextScope {
     43             val classLoader = scriptRuntime.androidContext.classLoader
     44             moduleObject = initSafeStandardObjects()
     45             moduleObject.putConst("module", moduleObject, scriptableObject {
     46                 putConst("info", this, scriptableObject {
     47                     putConst("name", this, moduleInfo.name)
     48                     putConst("version", this, moduleInfo.version)
     49                     putConst("displayName", this, moduleInfo.displayName)
     50                     putConst("description", this, moduleInfo.description)
     51                     putConst("author", this, moduleInfo.author)
     52                     putConst("minSnapchatVersion", this, moduleInfo.minSnapchatVersion)
     53                     putConst("minSEVersion", this, moduleInfo.minSEVersion)
     54                     putConst("grantedPermissions", this, moduleInfo.grantedPermissions)
     55                 })
     56             })
     57 
     58             scriptRuntime.logger.apply {
     59                 moduleObject.putConst("console", moduleObject, scriptableObject {
     60                     putFunction("log") { info(argsToString(it)) }
     61                     putFunction("warn") { warn(argsToString(it)) }
     62                     putFunction("error") { error(argsToString(it)) }
     63                     putFunction("debug") { debug(argsToString(it)) }
     64                     putFunction("info") { info(argsToString(it)) }
     65                     putFunction("trace") { verbose(argsToString(it)) }
     66                     putFunction("verbose") { verbose(argsToString(it)) }
     67                 })
     68             }
     69 
     70             registerBindings(
     71                 JavaInterfaces(),
     72                 InterfaceManager(),
     73                 Networking(),
     74                 Protobuf()
     75             )
     76 
     77             moduleObject.putFunction("setField") { args ->
     78                 val obj = args?.get(0) ?: return@putFunction Undefined.instance
     79                 val name = args[1].toString()
     80                 val value = args[2]
     81                 val field = obj.javaClass.declaredFields.find { it.name == name } ?: return@putFunction Undefined.instance
     82                 field.isAccessible = true
     83                 field.set(obj, value.toPrimitiveValue(lazy { field.type.name }))
     84                 Undefined.instance
     85             }
     86 
     87             moduleObject.putFunction("getField") { args ->
     88                 val obj = args?.get(0) ?: return@putFunction Undefined.instance
     89                 val name = args[1].toString()
     90                 val field = obj.javaClass.declaredFields.find { it.name == name } ?: return@putFunction Undefined.instance
     91                 field.isAccessible = true
     92                 field.get(obj)
     93             }
     94 
     95             moduleObject.putFunction("sleep") { args ->
     96                 val time = args?.get(0) as? Number ?: return@putFunction Undefined.instance
     97                 Thread.sleep(time.toLong())
     98                 Undefined.instance
     99             }
    100 
    101             moduleObject.putFunction("findClass") {
    102                 val className = it?.get(0).toString()
    103                 val useModClassLoader = it?.getOrNull(1) as? Boolean ?: false
    104                 if (useModClassLoader) moduleInfo.ensurePermissionGranted(Permissions.UNSAFE_CLASSLOADER)
    105 
    106                 runCatching {
    107                     if (useModClassLoader) this::class.java.classLoader?.loadClass(className)
    108                     else classLoader.loadClass(className)
    109                 }.onFailure { throwable ->
    110                     scriptRuntime.logger.error("Failed to load class $className", throwable)
    111                 }.getOrNull()
    112             }
    113 
    114             moduleObject.putFunction("type") { args ->
    115                 val className = args?.get(0).toString()
    116                 val useModClassLoader = args?.getOrNull(1) as? Boolean ?: false
    117                 if (useModClassLoader) moduleInfo.ensurePermissionGranted(Permissions.UNSAFE_CLASSLOADER)
    118 
    119                 val clazz = runCatching {
    120                     if (useModClassLoader) this::class.java.classLoader?.loadClass(className) else classLoader.loadClass(className)
    121                 }.getOrNull() ?: return@putFunction Undefined.instance
    122 
    123                 scriptableObject("JavaClassWrapper") {
    124                     val newInstance: (Array<out Any?>?) -> Any? = { args ->
    125                         val constructor = clazz.declaredConstructors.find {
    126                             (args ?: emptyArray()).isSameParameters(it.parameterTypes)
    127                         }?.also { it.isAccessible = true } ?: throw IllegalArgumentException("Constructor not found with args ${argsToString(args)}")
    128                         constructor.newInstance(*args ?: emptyArray())
    129                     }
    130                     putFunction("__new__") { newInstance(it) }
    131 
    132                     clazz.declaredMethods.filter { Modifier.isStatic(it.modifiers) }.forEach { method ->
    133                         putFunction(method.name) { args ->
    134                             val declaredMethod = clazz.declaredMethods.find {
    135                                 it.name == method.name && (args ?: emptyArray()).isSameParameters(it.parameterTypes)
    136                             }?.also { it.isAccessible = true } ?: throw IllegalArgumentException("Method ${method.name} not found with args ${argsToString(args)}")
    137                             declaredMethod.invoke(null, *args ?: emptyArray())
    138                         }
    139                     }
    140 
    141                     clazz.declaredFields.filter { Modifier.isStatic(it.modifiers) }.forEach { field ->
    142                         field.isAccessible = true
    143                         defineProperty(field.name, { field.get(null) }, { value -> field.set(null, value) }, 0)
    144                     }
    145 
    146                     if (get("newInstance") == null) {
    147                         putFunction("newInstance") { newInstance(it) }
    148                     }
    149                 }
    150             }
    151 
    152             moduleObject.putFunction("logInfo") { args ->
    153                 scriptRuntime.logger.info(argsToString(args))
    154                 Undefined.instance
    155             }
    156 
    157             moduleObject.putFunction("logError") { args ->
    158                 scriptRuntime.logger.error(argsToString(arrayOf(args?.get(0))), args?.getOrNull(1) as? Throwable ?: Throwable())
    159                 Undefined.instance
    160             }
    161 
    162             moduleObject.putFunction("setTimeout") {
    163                 val function = it?.get(0) as? Function ?: return@putFunction Undefined.instance
    164                 val time = it[1] as? Number ?: 0
    165 
    166                 return@putFunction coroutineScope.launch {
    167                     delay(time.toLong())
    168                     contextScope {
    169                         function.call(this, this@putFunction, this@putFunction, emptyArray())
    170                     }
    171                 }
    172             }
    173 
    174             moduleObject.putFunction("setInterval") {
    175                 val function = it?.get(0) as? Function ?: return@putFunction Undefined.instance
    176                 val time = it[1] as? Number ?: 0
    177 
    178                 return@putFunction coroutineScope.launch {
    179                     while (true) {
    180                         delay(time.toLong())
    181                         contextScope {
    182                             function.call(this, this@putFunction, this@putFunction, emptyArray())
    183                         }
    184                     }
    185                 }
    186             }
    187 
    188             arrayOf("clearInterval", "clearTimeout").forEach {
    189                 moduleObject.putFunction(it) { args ->
    190                     val job = args?.get(0) as? Job ?: return@putFunction Undefined.instance
    191                     runCatching {
    192                         job.cancel()
    193                     }
    194                     Undefined.instance
    195                 }
    196             }
    197 
    198             for (toastFunc in listOf("longToast", "shortToast")) {
    199                 moduleObject.putFunction(toastFunc) { args ->
    200                     Handler(scriptRuntime.androidContext.mainLooper).post {
    201                         Toast.makeText(
    202                             scriptRuntime.androidContext,
    203                             args?.joinToString(" ") ?: "",
    204                             if (toastFunc == "longToast") Toast.LENGTH_LONG else Toast.LENGTH_SHORT
    205                         ).show()
    206                     }
    207                     Undefined.instance
    208                 }
    209             }
    210 
    211             block(moduleObject)
    212 
    213             moduleBindings.forEach { (_, instance) ->
    214                 instance.context = moduleBindingContext
    215 
    216                 runCatching {
    217                     instance.onInit()
    218                 }.onFailure {
    219                     scriptRuntime.logger.error("Failed to init binding ${instance.name}", it)
    220                 }
    221             }
    222 
    223             moduleObject.putFunction("require") { args ->
    224                 val bindingName = args?.get(0).toString()
    225                 val (namespace, path) = bindingName.takeIf {
    226                     it.startsWith("@") && it.contains("/")
    227                 }?.let {
    228                     it.substring(1).substringBefore("/") to it.substringAfter("/")
    229                 } ?: (null to "")
    230 
    231                 when (namespace) {
    232                     "modules" -> scriptRuntime.getModuleByName(path)?.moduleObject?.scriptable("module")?.scriptable("exports")
    233                     else -> moduleBindings[bindingName]?.getObject()
    234                 }
    235             }
    236         }
    237 
    238         contextScope(shouldOptimize = true) {
    239             evaluateReader(moduleObject, reader, moduleInfo.name, 1, null)
    240         }
    241     }
    242 
    243     fun unload() {
    244         callFunction("module.onUnload")
    245         runCatching {
    246             coroutineScope.cancel("Module unloaded")
    247         }
    248         moduleBindings.entries.removeIf { (name, binding) ->
    249             runCatching {
    250                 binding.onDispose()
    251             }.onFailure {
    252                 scriptRuntime.logger.error("Failed to dispose binding $name", it)
    253             }
    254             true
    255         }
    256     }
    257 
    258     fun callFunction(name: String, vararg args: Any?) {
    259         contextScope {
    260             name.split(".").also { split ->
    261                 val function = split.dropLast(1).fold(moduleObject) { obj, key ->
    262                     obj.get(key, obj) as? ScriptableObject ?: return@contextScope Unit
    263                 }.get(split.last(), moduleObject) as? Function ?: return@contextScope Unit
    264 
    265                 runCatching {
    266                     function.call(this, moduleObject, moduleObject, args)
    267                 }.onFailure {
    268                     scriptRuntime.logger.error("Error while calling function $name", it)
    269                 }
    270             }
    271         }
    272     }
    273 
    274     fun registerBindings(vararg bindings: AbstractBinding) {
    275         bindings.forEach {
    276             moduleBindings[it.name] = it.apply {
    277                 context = moduleBindingContext
    278             }
    279         }
    280     }
    281 
    282     fun onBridgeConnected(reloaded: Boolean = false) {
    283         if (reloaded) {
    284             moduleBindings.values.forEach { binding ->
    285                 runCatching {
    286                     binding.onBridgeReloaded()
    287                 }.onFailure {
    288                     scriptRuntime.logger.error("Failed to call onBridgeConnected for binding ${binding.name}", it)
    289                 }
    290             }
    291         }
    292 
    293         callFunction("module.onBridgeConnected", reloaded)
    294     }
    295 
    296     @Suppress("UNCHECKED_CAST")
    297     fun <T : Any> getBinding(clazz: KClass<T>): T? {
    298         return moduleBindings.values.find { clazz.isInstance(it) } as? T
    299     }
    300 
    301     private fun argsToString(args: Array<out Any?>?): String {
    302         return args?.joinToString(" ") {
    303             when (it) {
    304                 is Wrapper -> it.unwrap().let { value ->
    305                     if (value is Throwable) value.message + "\n" + value.stackTraceToString()
    306                     else value.toString()
    307                 }
    308                 else -> it.toString()
    309             }
    310         } ?: "null"
    311     }
    312 }