2
0
Эх сурвалжийг харах

Add fallback solution for Services (#2511)

* add fallback solution for Services

* use castUp

* throw exception when prop doesn't match

* cannot use castUp

* improve codes

* improve name of functions

* add both

* add override

* solve conflicts

* [core] Move MiraiCoreServices to common

* [core] Improvement

* update var names

* update func names

---------

Co-authored-by: Karlatemp <kar@kasukusakura.com>
Eritque arcus 2 жил өмнө
parent
commit
3ff2737b3c

+ 67 - 0
mirai-core-utils/src/commonMain/kotlin/Services.kt

@@ -11,9 +11,76 @@
 
 package net.mamoe.mirai.utils
 
+import kotlinx.atomicfu.locks.reentrantLock
+import kotlinx.atomicfu.locks.withLock
 import kotlin.jvm.JvmName
 import kotlin.reflect.KClass
 
+public object Services {
+    private val lock = reentrantLock()
+    public fun <T : Any> qualifiedNameOrFail(clazz: KClass<out T>): String =
+        clazz.qualifiedName ?: error("Could not find qualifiedName for $clazz")
+
+    internal class Implementation(
+        val implementationClass: String,
+        val instance: Lazy<Any>
+    )
+
+    private val registered: MutableMap<String, MutableList<Implementation>> = mutableMapOf()
+    private val overrided: MutableMap<String, Implementation> = mutableMapOf()
+
+    @Suppress("UNCHECKED_CAST")
+    public fun <T : Any> getOverrideOrNull(clazz: KClass<out T>): T? {
+        lock.withLock {
+            return overrided[qualifiedNameOrFail(clazz)]?.instance?.value as T?
+        }
+    }
+
+    internal fun registerAsOverride(baseClass: String, implementationClass: String, implementation: () -> Any) {
+        lock.withLock {
+            overrided[baseClass] = Implementation(implementationClass, lazy(implementation))
+        }
+    }
+
+    public fun register(baseClass: String, implementationClass: String, implementation: () -> Any) {
+        lock.withLock {
+            registered.getOrPut(baseClass, ::mutableListOf)
+                .add(Implementation(implementationClass, lazy(implementation)))
+        }
+    }
+
+    public fun firstImplementationOrNull(baseClass: String): Any? {
+        lock.withLock {
+            overrided[baseClass]?.let { return it.instance.value }
+            return registered[baseClass]?.firstOrNull()?.instance?.value
+        }
+    }
+
+    public fun implementations(baseClass: String): Sequence<Lazy<Any>>? {
+        lock.withLock {
+            val implementations = registered[baseClass]
+            val forced = overrided[baseClass]
+            if (forced == null && implementations == null) return null
+
+            val implementationsSnapshot = implementations?.toList().orEmpty()
+
+            return sequence {
+                if (forced != null) yield(forced.instance)
+
+                implementationsSnapshot.forEach { yield(it.instance) }
+            }
+        }
+    }
+
+    internal fun implementationsDirectly(baseClass: String) = lock.withLock { registered[baseClass]?.toList().orEmpty() }
+
+    public fun print(): String {
+        lock.withLock {
+            return registered.entries.joinToString { "${it.key}:${it.value}" }
+        }
+    }
+}
+
 public expect fun <T : Any> loadServiceOrNull(clazz: KClass<out T>, fallbackImplementation: String? = null): T?
 public expect fun <T : Any> loadService(clazz: KClass<out T>, fallbackImplementation: String? = null): T
 public expect fun <T : Any> loadServices(clazz: KClass<out T>): Sequence<T>

+ 77 - 12
mirai-core-utils/src/jvmBaseMain/kotlin/Services.kt

@@ -13,15 +13,51 @@ import java.util.*
 import kotlin.reflect.KClass
 import kotlin.reflect.full.createInstance
 
+private enum class LoaderType {
+    JDK,
+    BOTH,
+    FALLBACK,
+}
+
+private val loaderType = when (systemProp("mirai.service.loader", "both")) {
+    "jdk" -> LoaderType.JDK
+    "both" -> LoaderType.BOTH
+    "fallback" -> LoaderType.FALLBACK
+    else -> throw IllegalStateException("cannot find a service loader, mirai.service.loader must be both, jdk or fallback (default by both)")
+}
+
+@Suppress("UNCHECKED_CAST")
 public actual fun <T : Any> loadService(clazz: KClass<out T>, fallbackImplementation: String?): T {
+    val fallbackService by lazy {
+        Services.firstImplementationOrNull(Services.qualifiedNameOrFail(clazz)) as T?
+    }
+
+    val jdkService by lazy {
+        ServiceLoader.load(clazz.java).firstOrNull()?.let { return@lazy it }
+
+        ServiceLoader.load(clazz.java, clazz.java.classLoader).firstOrNull()
+    }
+
     var suppressed: Throwable? = null
-    return ServiceLoader.load(clazz.java).firstOrNull()
-        ?: ServiceLoader.load(clazz.java, clazz.java.classLoader).firstOrNull()
-        ?: (if (fallbackImplementation == null) null
-        else runCatching { findCreateInstance<T>(fallbackImplementation) }.onFailure { suppressed = it }.getOrNull())
-        ?: throw NoSuchElementException("Could not find an implementation for service class ${clazz.qualifiedName}").apply {
-            if (suppressed != null) addSuppressed(suppressed)
-        }
+
+    val services by lazy {
+        when (loaderType) {
+            LoaderType.JDK -> jdkService
+            LoaderType.BOTH -> jdkService ?: fallbackService
+            LoaderType.FALLBACK -> fallbackService
+        }?.let { return@lazy it }
+
+        if (fallbackImplementation != null) {
+            runCatching {
+                findCreateInstance<T>(fallbackImplementation)
+            }.onFailure { suppressed = it }.getOrNull()
+        } else null
+    }
+
+    return Services.getOverrideOrNull(clazz) ?: services
+    ?: throw NoSuchElementException("Could not find an implementation for service class ${clazz.qualifiedName}").apply {
+        if (suppressed != null) addSuppressed(suppressed)
+    }
 }
 
 private fun <T : Any> findCreateInstance(fallbackImplementation: String): T {
@@ -29,14 +65,17 @@ private fun <T : Any> findCreateInstance(fallbackImplementation: String): T {
 }
 
 public actual fun <T : Any> loadServiceOrNull(clazz: KClass<out T>, fallbackImplementation: String?): T? {
-    return ServiceLoader.load(clazz.java).firstOrNull()
-        ?: ServiceLoader.load(clazz.java, clazz.java.classLoader).firstOrNull()
-        ?: if (fallbackImplementation == null) return null
-        else runCatching { findCreateInstance<T>(fallbackImplementation) }.getOrNull()
+    return runCatching { loadService(clazz, fallbackImplementation) }.getOrNull()
 }
 
+@Suppress("UNCHECKED_CAST")
 public actual fun <T : Any> loadServices(clazz: KClass<out T>): Sequence<T> {
-    return sequence {
+    fun fallBackServicesSeq(): Sequence<T> {
+        return Services.implementations(Services.qualifiedNameOrFail(clazz)).orEmpty()
+            .map { it.value as T }
+    }
+
+    fun jdkServices(): Sequence<T> = sequence {
         val current = ServiceLoader.load(clazz.java).iterator()
         if (current.hasNext()) {
             yieldAll(current)
@@ -44,4 +83,30 @@ public actual fun <T : Any> loadServices(clazz: KClass<out T>): Sequence<T> {
             yieldAll(ServiceLoader.load(clazz.java, clazz.java.classLoader))
         }
     }
+
+    fun bothServices(): Sequence<T> = sequence {
+        Services.getOverrideOrNull(clazz)?.let { yield(it) }
+
+        var jdkServices = ServiceLoader.load(clazz.java).toList()
+        if (jdkServices.isEmpty()) {
+            jdkServices = ServiceLoader.load(clazz.java, clazz.java.classLoader).toList()
+        }
+        yieldAll(jdkServices)
+
+        Services.implementationsDirectly(Services.qualifiedNameOrFail(clazz)).asSequence()
+            .filter { impl ->
+                // Drop duplicated
+                jdkServices.none { it.javaClass.name == impl.implementationClass }
+            }
+            .forEach { yield(it.instance.value as T) }
+    }
+
+
+
+
+    return when (loaderType) {
+        LoaderType.JDK -> jdkServices()
+        LoaderType.BOTH -> bothServices()
+        LoaderType.FALLBACK -> fallBackServicesSeq()
+    }
 }

+ 2 - 43
mirai-core-utils/src/nativeMain/kotlin/Service.kt

@@ -11,47 +11,9 @@
 
 package net.mamoe.mirai.utils
 
-import kotlinx.atomicfu.locks.reentrantLock
-import kotlinx.atomicfu.locks.withLock
+import net.mamoe.mirai.utils.Services.qualifiedNameOrFail
 import kotlin.reflect.KClass
 
-public object Services {
-    private val lock = reentrantLock()
-
-    private class Implementation(
-        val implementationClass: String,
-        val instance: Lazy<Any>
-    )
-
-    private val registered: MutableMap<String, MutableList<Implementation>> = mutableMapOf()
-
-    public fun register(baseClass: String, implementationClass: String, implementation: () -> Any) {
-        lock.withLock {
-            registered.getOrPut(baseClass, ::mutableListOf)
-                .add(Implementation(implementationClass, lazy(implementation)))
-        }
-    }
-
-    public fun firstImplementationOrNull(baseClass: String): Any? {
-        lock.withLock {
-            return registered[baseClass]?.firstOrNull()?.instance?.value
-        }
-    }
-
-    public fun implementations(baseClass: String): List<Lazy<Any>>? {
-        lock.withLock {
-            return registered[baseClass]?.map { it.instance }
-        }
-
-    }
-
-    public fun print(): String {
-        lock.withLock {
-            return registered.entries.joinToString { "${it.key}:${it.value}" }
-        }
-    }
-}
-
 @Suppress("UNCHECKED_CAST")
 public actual fun <T : Any> loadServiceOrNull(
     clazz: KClass<out T>,
@@ -66,7 +28,4 @@ public actual fun <T : Any> loadService(
     ?: error("Could not load service '${clazz.qualifiedName ?: clazz}'. Current services: ${Services.print()}")
 
 public actual fun <T : Any> loadServices(clazz: KClass<out T>): Sequence<T> =
-    Services.implementations(qualifiedNameOrFail(clazz))?.asSequence()?.map { it.value }.orEmpty().castUp()
-
-private fun <T : Any> qualifiedNameOrFail(clazz: KClass<out T>) =
-    clazz.qualifiedName ?: error("Could not find qualifiedName for $clazz")
+    Services.implementations(qualifiedNameOrFail(clazz)).orEmpty().map { it.value }.castUp()

+ 0 - 0
mirai-core/src/nativeMain/kotlin/utils/MiraiCoreServices.kt → mirai-core/src/commonMain/kotlin/utils/MiraiCoreServices.kt


+ 2 - 0
mirai-core/src/jvmBaseMain/kotlin/MiraiImpl.kt

@@ -16,12 +16,14 @@ import io.ktor.client.engine.okhttp.*
 import io.ktor.client.plugins.*
 import kotlinx.atomicfu.atomic
 import net.mamoe.mirai.internal.message.protocol.MessageProtocolFacade
+import net.mamoe.mirai.internal.utils.MiraiCoreServices
 
 private val initialized = atomic(false)
 
 @Suppress("FunctionName")
 internal actual fun _MiraiImpl_static_init() {
     if (!initialized.compareAndSet(expect = false, update = true)) return
+    MiraiCoreServices.registerAll()
     MessageProtocolFacade.INSTANCE // register serializers
 }