瀏覽代碼

[core] Fix polymorphic serialization

Karlatemp 2 年之前
父節點
當前提交
4d9f6e88f9

+ 1 - 1
mirai-core-api/src/commonMain/kotlin/contact/announcement/OfflineAnnouncement.kt

@@ -84,7 +84,7 @@ public sealed interface OfflineAnnouncement : Announcement {
         }
         }
 
 
         internal object Serializer : KSerializer<OfflineAnnouncement> by OfflineAnnouncementImpl.serializer().map(
         internal object Serializer : KSerializer<OfflineAnnouncement> by OfflineAnnouncementImpl.serializer().map(
-            resultantDescriptor = OfflineAnnouncementImpl.serializer().descriptor.copy(SERIAL_NAME),
+            resultantDescriptor = OfflineAnnouncementImpl.serializer().descriptor,
             deserialize = { it },
             deserialize = { it },
             serialize = { it.safeCast<OfflineAnnouncementImpl>() ?: create(content, parameters).cast() }
             serialize = { it.safeCast<OfflineAnnouncementImpl>() ?: create(content, parameters).cast() }
         )
         )

+ 6 - 0
mirai-core-api/src/commonMain/kotlin/internal/message/MessageSerializersImpl.kt

@@ -38,6 +38,12 @@ public open class MessageSourceSerializerImpl(serialName: String) :
             Mirai.constructMessageSource(botId, kind, fromId, targetId, ids, time, internalIds, originalMessage)
             Mirai.constructMessageSource(botId, kind, fromId, targetId, ids, time, internalIds, originalMessage)
         }
         }
     ) {
     ) {
+
+    @MiraiInternalApi
+    public companion object {
+        public fun serialDataSerializer(): KSerializer<*> = SerialData.serializer()
+    }
+
     @SerialName(MessageSource.SERIAL_NAME)
     @SerialName(MessageSource.SERIAL_NAME)
     @Serializable
     @Serializable
     internal class SerialData(
     internal class SerialData(

+ 2 - 3
mirai-core-api/src/commonMain/kotlin/message/data/FileMessage.kt

@@ -25,7 +25,6 @@ import net.mamoe.mirai.message.code.CodableMessage
 import net.mamoe.mirai.message.data.visitor.MessageVisitor
 import net.mamoe.mirai.message.data.visitor.MessageVisitor
 import net.mamoe.mirai.utils.MiraiInternalApi
 import net.mamoe.mirai.utils.MiraiInternalApi
 import net.mamoe.mirai.utils.NotStableForInheritance
 import net.mamoe.mirai.utils.NotStableForInheritance
-import net.mamoe.mirai.utils.copy
 import net.mamoe.mirai.utils.map
 import net.mamoe.mirai.utils.map
 import kotlin.jvm.JvmMultifileClass
 import kotlin.jvm.JvmMultifileClass
 import kotlin.jvm.JvmName
 import kotlin.jvm.JvmName
@@ -111,9 +110,9 @@ public expect interface FileMessage : MessageContent, ConstrainSingle, CodableMe
 }
 }
 
 
 @MiraiInternalApi
 @MiraiInternalApi
-internal open class FallbackFileMessageSerializer constructor(serialName: String) :
+internal open class FallbackFileMessageSerializer :
     KSerializer<FileMessage> by Delegate.serializer().map(
     KSerializer<FileMessage> by Delegate.serializer().map(
-        Delegate.serializer().descriptor.copy(serialName),
+        Delegate.serializer().descriptor,
         serialize = { Delegate(id, internalId, name, size) },
         serialize = { Delegate(id, internalId, name, size) },
         deserialize = { Mirai.createFileMessage(id, internalId, name, size) },
         deserialize = { Mirai.createFileMessage(id, internalId, name, size) },
     ) {
     ) {

+ 1 - 1
mirai-core-api/src/commonMain/kotlin/message/data/RockPaperScissors.kt

@@ -102,7 +102,7 @@ public enum class RockPaperScissors(
     }
     }
 
 
     internal object Serializer : KSerializer<RockPaperScissors> by Surrogate.serializer().map(
     internal object Serializer : KSerializer<RockPaperScissors> by Surrogate.serializer().map(
-        resultantDescriptor = Surrogate.serializer().descriptor.copy(SERIAL_NAME),
+        resultantDescriptor = Surrogate.serializer().descriptor,
         deserialize = { valueOf(it.name) },
         deserialize = { valueOf(it.name) },
         serialize = { Surrogate(name) },
         serialize = { Surrogate(name) },
     ) {
     ) {

+ 1 - 1
mirai-core-api/src/commonMain/kotlin/message/data/UnsupportedMessage.kt

@@ -62,7 +62,7 @@ public interface UnsupportedMessage : MessageContent {
     }
     }
 
 
     public object Serializer : KSerializer<UnsupportedMessage> by Surrogate.serializer().map(
     public object Serializer : KSerializer<UnsupportedMessage> by Surrogate.serializer().map(
-        resultantDescriptor = Surrogate.serializer().descriptor.copy(SERIAL_NAME),
+        resultantDescriptor = Surrogate.serializer().descriptor,
         deserialize = { Mirai.createUnsupportedMessage(struct.hexToBytes()) },
         deserialize = { Mirai.createUnsupportedMessage(struct.hexToBytes()) },
         serialize = { Surrogate(struct.toUHexString("")) }
         serialize = { Surrogate(struct.toUHexString("")) }
     ) {
     ) {

+ 3 - 1
mirai-core-api/src/commonMain/kotlin/utils/DeviceInfo.kt

@@ -11,6 +11,7 @@ package net.mamoe.mirai.utils
 
 
 import io.ktor.utils.io.core.*
 import io.ktor.utils.io.core.*
 import kotlinx.serialization.KSerializer
 import kotlinx.serialization.KSerializer
+import kotlinx.serialization.SerialName
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.Transient
 import kotlinx.serialization.Transient
 import kotlinx.serialization.builtins.serializer
 import kotlinx.serialization.builtins.serializer
@@ -292,7 +293,7 @@ internal object DeviceInfoManager {
     )
     )
 
 
     private object DeviceInfoVersionSerializer : KSerializer<DeviceInfo.Version> by SerialData.serializer().map(
     private object DeviceInfoVersionSerializer : KSerializer<DeviceInfo.Version> by SerialData.serializer().map(
-        resultantDescriptor = SerialData.serializer().descriptor.copy("Version"),
+        resultantDescriptor = SerialData.serializer().descriptor,
         deserialize = {
         deserialize = {
             DeviceInfo.Version(incremental, release, codename, sdk)
             DeviceInfo.Version(incremental, release, codename, sdk)
         },
         },
@@ -300,6 +301,7 @@ internal object DeviceInfoManager {
             SerialData(incremental, release, codename, sdk)
             SerialData(incremental, release, codename, sdk)
         }
         }
     ) {
     ) {
+        @SerialName("Version")
         @Serializable
         @Serializable
         private class SerialData(
         private class SerialData(
             val incremental: ByteArray = "5891938".toByteArray(),
             val incremental: ByteArray = "5891938".toByteArray(),

+ 1 - 1
mirai-core-api/src/jvmBaseMain/kotlin/message/data/FileMessage.kt

@@ -120,5 +120,5 @@ public actual interface FileMessage : MessageContent, ConstrainSingle, CodableMe
     }
     }
 
 
     public actual object Serializer :
     public actual object Serializer :
-        KSerializer<FileMessage> by FallbackFileMessageSerializer(SERIAL_NAME) // not polymorphic
+        KSerializer<FileMessage> by FallbackFileMessageSerializer()
 }
 }

+ 1 - 1
mirai-core-api/src/nativeMain/kotlin/message/data/FileMessage.kt

@@ -104,5 +104,5 @@ public actual interface FileMessage : MessageContent, ConstrainSingle, CodableMe
     }
     }
 
 
     public actual object Serializer :
     public actual object Serializer :
-        KSerializer<FileMessage> by FallbackFileMessageSerializer(SERIAL_NAME) // not polymorphic
+        KSerializer<FileMessage> by FallbackFileMessageSerializer()
 }
 }

+ 2 - 2
mirai-core/src/commonMain/kotlin/message/data/audio.kt

@@ -163,7 +163,7 @@ internal class OnlineAudioImpl(
     }
     }
 
 
     object Serializer : KSerializer<OnlineAudioImpl> by Surrogate.serializer().map(
     object Serializer : KSerializer<OnlineAudioImpl> by Surrogate.serializer().map(
-        resultantDescriptor = Surrogate.serializer().descriptor.copy(OnlineAudio.SERIAL_NAME),
+        resultantDescriptor = Surrogate.serializer().descriptor,
         deserialize = {
         deserialize = {
             OnlineAudioImpl(
             OnlineAudioImpl(
                 filename = filename,
                 filename = filename,
@@ -251,7 +251,7 @@ internal class OfflineAudioImpl(
     }
     }
 
 
     object Serializer : KSerializer<OfflineAudioImpl> by Surrogate.serializer().map(
     object Serializer : KSerializer<OfflineAudioImpl> by Surrogate.serializer().map(
-        resultantDescriptor = Surrogate.serializer().descriptor.copy(OfflineAudio.SERIAL_NAME),
+        resultantDescriptor = Surrogate.serializer().descriptor,
         deserialize = {
         deserialize = {
             OfflineAudioImpl(
             OfflineAudioImpl(
                 filename = filename,
                 filename = filename,

+ 2 - 2
mirai-core/src/commonMain/kotlin/message/protocol/impl/QuoteReplyProtocol.kt

@@ -10,6 +10,7 @@
 package net.mamoe.mirai.internal.message.protocol.impl
 package net.mamoe.mirai.internal.message.protocol.impl
 
 
 import net.mamoe.mirai.contact.AnonymousMember
 import net.mamoe.mirai.contact.AnonymousMember
+import net.mamoe.mirai.internal.message.MessageSourceSerializerImpl
 import net.mamoe.mirai.internal.message.protocol.MessageProtocol
 import net.mamoe.mirai.internal.message.protocol.MessageProtocol
 import net.mamoe.mirai.internal.message.protocol.ProcessorCollector
 import net.mamoe.mirai.internal.message.protocol.ProcessorCollector
 import net.mamoe.mirai.internal.message.protocol.decode.MessageDecoder
 import net.mamoe.mirai.internal.message.protocol.decode.MessageDecoder
@@ -24,7 +25,6 @@ import net.mamoe.mirai.internal.message.protocol.serialization.MessageSerializer
 import net.mamoe.mirai.internal.message.source.*
 import net.mamoe.mirai.internal.message.source.*
 import net.mamoe.mirai.internal.network.protocol.data.proto.ImMsgBody
 import net.mamoe.mirai.internal.network.protocol.data.proto.ImMsgBody
 import net.mamoe.mirai.message.data.*
 import net.mamoe.mirai.message.data.*
-import net.mamoe.mirai.utils.copy
 import net.mamoe.mirai.utils.map
 import net.mamoe.mirai.utils.map
 
 
 internal class QuoteReplyProtocol : MessageProtocol(PRIORITY_METADATA) {
 internal class QuoteReplyProtocol : MessageProtocol(PRIORITY_METADATA) {
@@ -100,7 +100,7 @@ internal class QuoteReplyProtocol : MessageProtocol(PRIORITY_METADATA) {
                 MessageSerializer(
                 MessageSerializer(
                     MessageSource::class,
                     MessageSource::class,
                     OfflineMessageSourceImplData.serializer().map(
                     OfflineMessageSourceImplData.serializer().map(
-                        OfflineMessageSourceImplData.serializer().descriptor.copy(MessageSource.SERIAL_NAME),
+                        MessageSourceSerializerImpl.serialDataSerializer().descriptor,
                         { it },
                         { it },
                         {
                         {
                             OfflineMessageSourceImplData(
                             OfflineMessageSourceImplData(

+ 34 - 0
mirai-core/src/commonTest/kotlin/message/protocol/impl/MarketFaceProtocolTest.kt

@@ -564,6 +564,11 @@ internal class MarketFaceProtocolTest : AbstractMessageProtocolTest() {
         override val message: Dice
         override val message: Dice
     ) : PolymorphicWrapper
     ) : PolymorphicWrapper
 
 
+    @Serializable
+    data class StaticWrapperRockPaperScissors(
+        override val message: RockPaperScissors
+    ) : PolymorphicWrapper
+
     private fun <M : MarketFace> testPolymorphicInMarketFace(
     private fun <M : MarketFace> testPolymorphicInMarketFace(
         data: M,
         data: M,
         expectedSerialName: String,
         expectedSerialName: String,
@@ -591,6 +596,19 @@ internal class MarketFaceProtocolTest : AbstractMessageProtocolTest() {
         )
         )
     })
     })
 
 
+    private fun testStaticRockPaperScissors(
+        data: RockPaperScissors,
+        expectedInstance: RockPaperScissors = data,
+    ) = listOf(dynamicTest("testStaticRockPaperScissors") {
+        testPolymorphicIn(
+            polySerializer = StaticWrapperRockPaperScissors.serializer(),
+            polyConstructor = ::StaticWrapperRockPaperScissors,
+            data = data,
+            expectedSerialName = null,
+            expectedInstance = expectedInstance,
+        )
+    })
+
     @TestFactory
     @TestFactory
     fun `test serialization for MarketFaceImpl`(): DynamicTestsResult {
     fun `test serialization for MarketFaceImpl`(): DynamicTestsResult {
         val data = MarketFaceImpl(
         val data = MarketFaceImpl(
@@ -617,6 +635,22 @@ internal class MarketFaceProtocolTest : AbstractMessageProtocolTest() {
         )
         )
     }
     }
 
 
+    @TestFactory
+    fun `test serialization for RockPaperScissors`(): DynamicTestsResult {
+        val data = RockPaperScissors.PAPER
+
+        val serialName = RockPaperScissors.SERIAL_NAME
+        return runDynamicTests(
+            testPolymorphicInMarketFace(data, serialName),
+            testPolymorphicInMessageContent(data, serialName),
+            testPolymorphicInSingleMessage(data, serialName),
+            testInsideMessageChain(data, serialName),
+            testContextual(data, serialName),
+            testContextual(data, serialName, targetType = MarketFace::class),
+            testStaticRockPaperScissors(data),
+        )
+    }
+
     @TestFactory
     @TestFactory
     fun `test serialization for Dice`(): DynamicTestsResult {
     fun `test serialization for Dice`(): DynamicTestsResult {
         val data = Dice(1)
         val data = Dice(1)

+ 21 - 12
mirai-core/src/commonTest/kotlin/message/protocol/impl/QuoteReplyProtocolTest.kt

@@ -9,8 +9,8 @@
 
 
 package net.mamoe.mirai.internal.message.protocol.impl
 package net.mamoe.mirai.internal.message.protocol.impl
 
 
+import kotlinx.serialization.Polymorphic
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.Serializable
-import kotlinx.serialization.json.Json
 import net.mamoe.mirai.internal.message.protocol.MessageProtocol
 import net.mamoe.mirai.internal.message.protocol.MessageProtocol
 import net.mamoe.mirai.internal.message.source.OfflineMessageSourceImplData
 import net.mamoe.mirai.internal.message.source.OfflineMessageSourceImplData
 import net.mamoe.mirai.internal.message.toMessageChainOnline
 import net.mamoe.mirai.internal.message.toMessageChainOnline
@@ -19,7 +19,6 @@ import net.mamoe.mirai.internal.testFramework.TestFactory
 import net.mamoe.mirai.internal.testFramework.dynamicTest
 import net.mamoe.mirai.internal.testFramework.dynamicTest
 import net.mamoe.mirai.internal.testFramework.runDynamicTests
 import net.mamoe.mirai.internal.testFramework.runDynamicTests
 import net.mamoe.mirai.internal.utils.runCoroutineInPlace
 import net.mamoe.mirai.internal.utils.runCoroutineInPlace
-import net.mamoe.mirai.message.MessageSerializers
 import net.mamoe.mirai.message.data.*
 import net.mamoe.mirai.message.data.*
 import net.mamoe.mirai.message.data.MessageSource.Key.quote
 import net.mamoe.mirai.message.data.MessageSource.Key.quote
 import net.mamoe.mirai.utils.EMPTY_BYTE_ARRAY
 import net.mamoe.mirai.utils.EMPTY_BYTE_ARRAY
@@ -488,18 +487,13 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
     ///////////////////////////////////////////////////////////////////////////
     ///////////////////////////////////////////////////////////////////////////
 
 
 
 
-    // TODO: 2022/7/20 MessageSource 在 MessageMetadata 的 scope 多态序列化后会输出 'type' = 'MessageSource', 这是期望的行为.
-    //  但是在反序列化时会错误 unknown field 'type'
-    override val format: Json
-        get() = Json {
-            prettyPrint = true
-            serializersModule = MessageSerializers.serializersModule
-            ignoreUnknownKeys = true
-        }
-
-
     @Serializable
     @Serializable
     data class PolymorphicWrapperMessageSource(
     data class PolymorphicWrapperMessageSource(
+        override val message: @Polymorphic MessageSource
+    ) : PolymorphicWrapper
+
+    @Serializable
+    data class StaticWrapperMessageSource(
         override val message: MessageSource
         override val message: MessageSource
     ) : PolymorphicWrapper
     ) : PolymorphicWrapper
 
 
@@ -516,6 +510,19 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
         )
         )
     })
     })
 
 
+    private fun <M : MessageSource> testStaticInMessageSource(
+        data: M,
+        expectedInstance: M = data,
+    ) = listOf(dynamicTest("testStaticInMessageSource") {
+        testPolymorphicIn(
+            polySerializer = StaticWrapperMessageSource.serializer(),
+            polyConstructor = ::StaticWrapperMessageSource,
+            data = data,
+            expectedInstance = expectedInstance,
+            expectedSerialName = null,
+        )
+    })
+
     @TestFactory
     @TestFactory
     fun `test serialization for OfflineMessageSource`(): DynamicTestsResult {
     fun `test serialization for OfflineMessageSource`(): DynamicTestsResult {
         val data = MessageSourceBuilder()
         val data = MessageSourceBuilder()
@@ -533,6 +540,7 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
             testPolymorphicInSingleMessage(data, serialName),
             testPolymorphicInSingleMessage(data, serialName),
             testInsideMessageChain(data, serialName),
             testInsideMessageChain(data, serialName),
             testContextual(data, serialName),
             testContextual(data, serialName),
+            testStaticInMessageSource(data),
         )
         )
     }
     }
 
 
@@ -548,6 +556,7 @@ internal class QuoteReplyProtocolTest : AbstractMessageProtocolTest() {
             testPolymorphicInSingleMessage(data, serialName, expectedInstance = expected),
             testPolymorphicInSingleMessage(data, serialName, expectedInstance = expected),
             testInsideMessageChain(data, serialName, expectedInstance = expected),
             testInsideMessageChain(data, serialName, expectedInstance = expected),
             testContextual(data, serialName, expectedInstance = expected),
             testContextual(data, serialName, expectedInstance = expected),
+            testStaticInMessageSource(data, expectedInstance = expected),
         )
         )
     }
     }
 }
 }