Ver código fonte

Improve performance of `hexToBytes` and simplify similar extensions

Him188 4 anos atrás
pai
commit
dd606c3022

+ 1 - 1
mirai-core-api/src/commonMain/kotlin/message/data/UnsupportedMessage.kt

@@ -52,7 +52,7 @@ public interface UnsupportedMessage : MessageContent {
 
 
     public object Serializer : KSerializer<UnsupportedMessage> by Surrogate.serializer().map(
     public object Serializer : KSerializer<UnsupportedMessage> by Surrogate.serializer().map(
         resultantDescriptor = Surrogate.serializer().descriptor.copy(SERIAL_NAME),
         resultantDescriptor = Surrogate.serializer().descriptor.copy(SERIAL_NAME),
-        deserialize = { Mirai.createUnsupportedMessage(struct.chunkedHexToBytes()) },
+        deserialize = { Mirai.createUnsupportedMessage(struct.hexToBytes()) },
         serialize = { Surrogate(struct.toUHexString("")) }
         serialize = { Surrogate(struct.toUHexString("")) }
     ) {
     ) {
         @Suppress("RemoveRedundantQualifierName")
         @Suppress("RemoveRedundantQualifierName")

+ 98 - 36
mirai-core-utils/src/commonMain/kotlin/Conversions.kt

@@ -115,41 +115,6 @@ public fun Byte.fixToUHex(): String = this.toUByte().fixToUHex()
 public fun UByte.fixToUHex(): String =
 public fun UByte.fixToUHex(): String =
     if (this.toInt() in 0..15) "0${this.toString(16).uppercase()}" else this.toString(16).uppercase()
     if (this.toInt() in 0..15) "0${this.toString(16).uppercase()}" else this.toString(16).uppercase()
 
 
-public fun String.hexToBytes(): ByteArray =
-    this.split(" ")
-        .asSequence()
-        .filterNot { it.isEmpty() }
-        .map { s -> s.toUByte(16).toByte() }
-        .toList()
-        .toByteArray()
-
-/**
- * 每 2 char 为一组, 转换 Hex 为 [ByteArray]
- *
- * 这个方法很累, 不建议经常使用.
- */
-public fun String.chunkedHexToBytes(): ByteArray =
-    this.asSequence().chunked(2).map { (it[0].toString() + it[1]).toUByte(16).toByte() }.toList().toByteArray()
-
-/**
- * 删掉全部空格和换行后每 2 char 为一组, 转换 Hex 为 [ByteArray].
- */
-public fun String.autoHexToBytes(): ByteArray =
-    this.replace("\n", "").replace(" ", "").asSequence().chunked(2).map {
-        (it[0].toString() + it[1]).toUByte(16).toByte()
-    }.toList().toByteArray()
-
-/**
- * 将无符号 Hex 转为 [UByteArray], 有根据 hex 的 [hashCode] 建立的缓存.
- */
-public fun String.hexToUBytes(): UByteArray =
-    this.split(" ")
-        .asSequence()
-        .filterNot { it.isEmpty() }
-        .map { s -> s.toUByte(16) }
-        .toList()
-        .toUByteArray()
-
 /**
 /**
  * 将 [this] 前 4 个 [Byte] 的 bits 合并为一个 [Int]
  * 将 [this] 前 4 个 [Byte] 的 bits 合并为一个 [Int]
  *
  *
@@ -171,4 +136,101 @@ public fun ByteArray.toInt(): Int =
     (this[0].toInt().and(255) shl 24) + (this[1].toInt().and(255) shl 16) + (this[2].toInt()
     (this[0].toInt().and(255) shl 24) + (this[1].toInt().and(255) shl 16) + (this[2].toInt()
         .and(255) shl 8) + (this[3].toInt().and(
         .and(255) shl 8) + (this[3].toInt().and(
         255
         255
-    ) shl 0)
+    ) shl 0)
+
+
+///////////////////////////////////////////////////////////////////////////
+// hexToBytes
+///////////////////////////////////////////////////////////////////////////
+
+
+private val byteStringCandidates = arrayOf('a'..'f', 'A'..'F', '0'..'9', ' '..' ')
+private const val CHUNK_SPACE = -1
+
+public fun String.hexToBytes(): ByteArray {
+    val array = ByteArray(countHexBytes())
+    forEachHexChunkIndexed { index, char1, char2 ->
+        array[index] = Byte.parseFromHexChunk(char1, char2)
+    }
+    return array
+}
+
+public fun String.hexToUBytes(): UByteArray {
+    val array = UByteArray(countHexBytes())
+    forEachHexChunkIndexed { index, char1, char2 ->
+        array[index] = Byte.parseFromHexChunk(char1, char2).toUByte()
+    }
+    return array
+}
+
+public fun Byte.Companion.parseFromHexChunk(char1: Char, char2: Char): Byte {
+    return (char1.digitToInt(16).shl(SIZE_BITS / 2) or char2.digitToInt(16)).toByte()
+}
+
+private inline fun String.forEachHexChunkIndexed(block: (index: Int, char1: Char, char2: Char) -> Unit) {
+    var index = 0
+    forEachHexChunk { char1: Char, char2: Char ->
+        block(index++, char1, char2)
+    }
+}
+
+private inline fun String.forEachHexChunk(block: (char1: Char, char2: Char) -> Unit) {
+    var chunkSize = 0
+    var char1: Char = 0.toChar()
+    for ((index, c) in this.withIndex()) { // compiler optimization
+        if (c == ' ') {
+            if (chunkSize != 0) {
+                throw IllegalArgumentException("Invalid size of chunk at index ${index.minus(1)}")
+            }
+            continue
+        }
+        if (c in 'a'..'f' || c in 'A'..'F' || c in '0'..'9') { // compiler optimization
+            when (chunkSize) {
+                0 -> {
+                    chunkSize = 1
+                    char1 = c
+                }
+                1 -> {
+                    block(char1, c)
+                    chunkSize = 0
+                }
+            }
+        } else {
+            throw IllegalArgumentException("Invalid char '$c' at index $index")
+        }
+    }
+    if (chunkSize != 0) {
+        throw IllegalArgumentException("Invalid size of chunk at end of string")
+    }
+}
+
+public fun String.countHexBytes(): Int {
+    var chunkSize = 0
+    var count = 0
+    for ((index, c) in this.withIndex()) {
+        if (c == ' ') {
+            if (chunkSize != 0) {
+                throw IllegalArgumentException("Invalid size of chunk at index ${index.minus(1)}")
+            }
+            continue
+        }
+        c.isDigit()
+        if (c in 'a'..'f' || c in 'A'..'F' || c in '0'..'9') {
+            when (chunkSize) {
+                0 -> {
+                    chunkSize = 1
+                }
+                1 -> {
+                    count++
+                    chunkSize = 0
+                }
+            }
+        } else {
+            throw IllegalArgumentException("Invalid char '$c' at index $index")
+        }
+    }
+    if (chunkSize != 0) {
+        throw IllegalArgumentException("Invalid size of chunk at end of string")
+    }
+    return count
+}

+ 137 - 0
mirai-core-utils/src/commonTest/kotlin/net/mamoe/mirai/utils/HexToBytesTest.kt

@@ -0,0 +1,137 @@
+/*
+ * Copyright 2019-2021 Mamoe Technologies and contributors.
+ *
+ * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证.
+ * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link.
+ *
+ * https://github.com/mamoe/mirai/blob/dev/LICENSE
+ */
+
+package net.mamoe.mirai.utils
+
+import kotlin.test.Test
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
+
+class HexToBytesTest {
+    private fun Byte.Companion.parseFromHexChunk(char1: String): Byte = parseFromHexChunk(char1[0], char1[1])
+
+    @Test
+    fun `Byte parseFromHexChunk`() {
+        assertEquals(0xff.toByte(), Byte.parseFromHexChunk("FF"))
+        assertEquals(0xff.toByte(), Byte.parseFromHexChunk("ff"))
+        assertEquals(0xff.toByte(), Byte.parseFromHexChunk("fF"))
+        assertEquals(0xff.toByte(), Byte.parseFromHexChunk("Ff"))
+
+        assertEquals(0x00.toByte(), Byte.parseFromHexChunk("00"))
+        assertEquals(0x0f.toByte(), Byte.parseFromHexChunk("0f"))
+        assertEquals(0x34.toByte(), Byte.parseFromHexChunk("34"))
+        assertEquals(0x7f.toByte(), Byte.parseFromHexChunk("7f"))
+    }
+
+    @Test
+    fun `test countHexBytes`() {
+        assertEquals(0, "".countHexBytes())
+
+        assertEquals(1, "01".countHexBytes())
+        assertEquals(1, "FF".countHexBytes())
+        assertEquals(1, "ff".countHexBytes())
+        assertEquals(1, "Ff".countHexBytes())
+        assertEquals(1, "fF".countHexBytes())
+        assertEquals(1, "0F".countHexBytes())
+        assertEquals(1, "F0".countHexBytes())
+        assertEquals(1, "0f".countHexBytes())
+        assertEquals(1, "f0".countHexBytes())
+
+        assertEquals(1, "01 ".countHexBytes())
+        assertEquals(1, "FF ".countHexBytes())
+        assertEquals(1, "ff ".countHexBytes())
+        assertEquals(1, "Ff ".countHexBytes())
+        assertEquals(1, "fF ".countHexBytes())
+        assertEquals(1, "0F ".countHexBytes())
+        assertEquals(1, "F0 ".countHexBytes())
+        assertEquals(1, "0f ".countHexBytes())
+        assertEquals(1, "f0 ".countHexBytes())
+
+        assertEquals(1, " 01 ".countHexBytes())
+        assertEquals(1, " FF ".countHexBytes())
+        assertEquals(1, " ff ".countHexBytes())
+        assertEquals(1, " Ff ".countHexBytes())
+        assertEquals(1, " fF ".countHexBytes())
+        assertEquals(1, " 0F ".countHexBytes())
+        assertEquals(1, " F0 ".countHexBytes())
+        assertEquals(1, " 0f ".countHexBytes())
+        assertEquals(1, " f0 ".countHexBytes())
+
+        assertEquals(1, " 01    ".countHexBytes())
+        assertEquals(1, " FF    ".countHexBytes())
+        assertEquals(1, " ff    ".countHexBytes())
+        assertEquals(1, " Ff    ".countHexBytes())
+        assertEquals(1, " fF    ".countHexBytes())
+        assertEquals(1, " 0F    ".countHexBytes())
+        assertEquals(1, " F0    ".countHexBytes())
+        assertEquals(1, " 0f    ".countHexBytes())
+        assertEquals(1, " f0    ".countHexBytes())
+
+        assertEquals(2, " 01   01   ".countHexBytes())
+        assertEquals(2, " FF   FF   ".countHexBytes())
+        assertEquals(2, " ff   ff   ".countHexBytes())
+        assertEquals(2, " Ff   Ff   ".countHexBytes())
+        assertEquals(2, " fF   fF   ".countHexBytes())
+        assertEquals(2, " 0F   0F   ".countHexBytes())
+        assertEquals(2, " F0   F0   ".countHexBytes())
+        assertEquals(2, " 0f   0f   ".countHexBytes())
+        assertEquals(2, " f0   f0   ".countHexBytes())
+
+        assertEquals(2, " 0101   ".countHexBytes())
+        assertEquals(2, " FFFF   ".countHexBytes())
+        assertEquals(2, " ffff   ".countHexBytes())
+        assertEquals(2, " FfFf   ".countHexBytes())
+        assertEquals(2, " fFfF   ".countHexBytes())
+        assertEquals(2, " 0F0F   ".countHexBytes())
+        assertEquals(2, " F0F0   ".countHexBytes())
+        assertEquals(2, " 0f0f   ".countHexBytes())
+        assertEquals(2, " f0f0   ".countHexBytes())
+
+        assertFailsWith<IllegalArgumentException> { "1".countHexBytes() }
+        assertFailsWith<IllegalArgumentException> { "0_1".countHexBytes() }
+        assertFailsWith<IllegalArgumentException> { "0 1".countHexBytes() }
+        assertFailsWith<IllegalArgumentException> { "g".countHexBytes() }
+        assertFailsWith<IllegalArgumentException> { "_".countHexBytes() }
+        assertFailsWith<IllegalArgumentException> { "123".countHexBytes() }
+        assertFailsWith<IllegalArgumentException> { "0x12".countHexBytes() }
+        assertFailsWith<IllegalArgumentException> { "12 3".countHexBytes() }
+    }
+
+    @Test
+    fun `test hexToBytes`() {
+        assertContentEquals(byteArrayOf(0xff.toByte()), "FF".hexToBytes())
+        assertContentEquals(byteArrayOf(0xff.toByte()), "ff".hexToBytes())
+        assertContentEquals(byteArrayOf(0xff.toByte()), "fF".hexToBytes())
+        assertContentEquals(byteArrayOf(0xff.toByte()), "Ff".hexToBytes())
+
+        assertContentEquals(byteArrayOf(0x00.toByte()), "00".hexToBytes())
+        assertContentEquals(byteArrayOf(0x0f.toByte()), "0f".hexToBytes())
+        assertContentEquals(byteArrayOf(0x34.toByte()), "34".hexToBytes())
+        assertContentEquals(byteArrayOf(0x7f.toByte()), "7f".hexToBytes())
+
+        assertContentEquals(byteArrayOf(0xff.toByte(), 0xff.toByte()), "     FF   FF  ".hexToBytes())
+        assertContentEquals(byteArrayOf(0xff.toByte(), 0xff.toByte()), "     ff   ff  ".hexToBytes())
+        assertContentEquals(byteArrayOf(0xff.toByte(), 0xff.toByte()), "     fF   fF  ".hexToBytes())
+        assertContentEquals(byteArrayOf(0xff.toByte(), 0xff.toByte()), "     Ff   Ff  ".hexToBytes())
+
+        assertContentEquals(byteArrayOf(0x00.toByte(), 0x00.toByte()), "     00   00  ".hexToBytes())
+        assertContentEquals(byteArrayOf(0x0f.toByte(), 0x0f.toByte()), "     0f   0f  ".hexToBytes())
+        assertContentEquals(byteArrayOf(0x34.toByte(), 0x34.toByte()), "     34   34  ".hexToBytes())
+        assertContentEquals(byteArrayOf(0x7f.toByte(), 0x7f.toByte()), "     7f   7f  ".hexToBytes())
+    }
+
+    @Test
+    fun `test hexToUBytes`() {
+        // implementations of hexToBytes and hexToUBytes are very similar.
+
+        assertContentEquals(ubyteArrayOf(0x7f.toUByte(), 0x7f.toUByte()), "     7f   7f  ".hexToUBytes())
+    }
+}
+

+ 7 - 7
mirai-core/src/commonMain/kotlin/message/MarketFaceImpl.kt

@@ -19,7 +19,7 @@ import net.mamoe.mirai.message.data.Dice
 import net.mamoe.mirai.message.data.MarketFace
 import net.mamoe.mirai.message.data.MarketFace
 import net.mamoe.mirai.message.data.Message
 import net.mamoe.mirai.message.data.Message
 import net.mamoe.mirai.message.data.MessageChain
 import net.mamoe.mirai.message.data.MessageChain
-import net.mamoe.mirai.utils.chunkedHexToBytes
+import net.mamoe.mirai.utils.hexToBytes
 
 
 @SerialName(MarketFace.SERIAL_NAME)
 @SerialName(MarketFace.SERIAL_NAME)
 @Serializable
 @Serializable
@@ -89,12 +89,12 @@ internal fun Dice.toJceStruct(): ImMsgBody.MarketFace {
  */
  */
 @Suppress("SpellCheckingInspection")
 @Suppress("SpellCheckingInspection")
 private val DICE_PC_FACE_IDS = mapOf(
 private val DICE_PC_FACE_IDS = mapOf(
-    1 to "E6EEDE15CDFBEB4DF0242448535354F1".chunkedHexToBytes(),
-    2 to "C5A95816FB5AFE34A58AF0E837A3B5A0".chunkedHexToBytes(),
-    3 to "382131D722EEA4624F087C5B8035AF5F".chunkedHexToBytes(),
-    4 to "FA90E956DCAD76742F2DB87723D3B669".chunkedHexToBytes(),
-    5 to "D51FA892017647431BB243920EC9FB8E".chunkedHexToBytes(),
-    6 to "7A2303AD80755FCB6BBFAC38327E0C01".chunkedHexToBytes(),
+    1 to "E6EEDE15CDFBEB4DF0242448535354F1".hexToBytes(),
+    2 to "C5A95816FB5AFE34A58AF0E837A3B5A0".hexToBytes(),
+    3 to "382131D722EEA4624F087C5B8035AF5F".hexToBytes(),
+    4 to "FA90E956DCAD76742F2DB87723D3B669".hexToBytes(),
+    5 to "D51FA892017647431BB243920EC9FB8E".hexToBytes(),
+    6 to "7A2303AD80755FCB6BBFAC38327E0C01".hexToBytes(),
 )
 )
 
 
 private fun ImMsgBody.MarketFace.toDiceOrNull(): Dice? {
 private fun ImMsgBody.MarketFace.toDiceOrNull(): Dice? {

+ 5 - 5
mirai-core/src/commonMain/kotlin/utils/crypto/ECDH.kt

@@ -11,8 +11,8 @@ package net.mamoe.mirai.internal.utils.crypto
 
 
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.Transient
 import kotlinx.serialization.Transient
-import net.mamoe.mirai.utils.chunkedHexToBytes
 import net.mamoe.mirai.utils.decodeBase64
 import net.mamoe.mirai.utils.decodeBase64
+import net.mamoe.mirai.utils.hexToBytes
 import java.security.KeyFactory
 import java.security.KeyFactory
 import java.security.spec.X509EncodedKeySpec
 import java.security.spec.X509EncodedKeySpec
 
 
@@ -40,8 +40,8 @@ internal interface ECDHKeyPair {
 
 
     object DefaultStub : ECDHKeyPair {
     object DefaultStub : ECDHKeyPair {
         val defaultPublicKey =
         val defaultPublicKey =
-            "04edb8906046f5bfbe9abbc5a88b37d70a6006bfbabc1f0cd49dfb33505e63efc5d78ee4e0a4595033b93d02096dcd3190279211f7b4f6785079e19004aa0e03bc".chunkedHexToBytes()
-        val defaultShareKey = "c129edba736f4909ecc4ab8e010f46a3".chunkedHexToBytes()
+            "04edb8906046f5bfbe9abbc5a88b37d70a6006bfbabc1f0cd49dfb33505e63efc5d78ee4e0a4595033b93d02096dcd3190279211f7b4f6785079e19004aa0e03bc".hexToBytes()
+        val defaultShareKey = "c129edba736f4909ecc4ab8e010f46a3".hexToBytes()
 
 
         override val privateKey: Nothing get() = error("stub!")
         override val privateKey: Nothing get() = error("stub!")
         override val publicKey: Nothing get() = error("stub!")
         override val publicKey: Nothing get() = error("stub!")
@@ -130,10 +130,10 @@ internal val publicKeyForVerify by lazy {
         .generatePublic(X509EncodedKeySpec("MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuJTW4abQJXeVdAODw1CamZH4QJZChyT08ribet1Gp0wpSabIgyKFZAOxeArcCbknKyBrRY3FFI9HgY1AyItH8DOUe6ajDEb6c+vrgjgeCiOiCVyum4lI5Fmp38iHKH14xap6xGaXcBccdOZNzGT82sPDM2Oc6QYSZpfs8EO7TYT7KSB2gaHz99RQ4A/Lel1Vw0krk+DescN6TgRCaXjSGn268jD7lOO23x5JS1mavsUJtOZpXkK9GqCGSTCTbCwZhI33CpwdQ2EHLhiP5RaXZCio6lksu+d8sKTWU1eEiEb3cQ7nuZXLYH7leeYFoPtbFV4RicIWp0/YG+RP7rLPCwIDAQAB".decodeBase64()))
         .generatePublic(X509EncodedKeySpec("MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuJTW4abQJXeVdAODw1CamZH4QJZChyT08ribet1Gp0wpSabIgyKFZAOxeArcCbknKyBrRY3FFI9HgY1AyItH8DOUe6ajDEb6c+vrgjgeCiOiCVyum4lI5Fmp38iHKH14xap6xGaXcBccdOZNzGT82sPDM2Oc6QYSZpfs8EO7TYT7KSB2gaHz99RQ4A/Lel1Vw0krk+DescN6TgRCaXjSGn268jD7lOO23x5JS1mavsUJtOZpXkK9GqCGSTCTbCwZhI33CpwdQ2EHLhiP5RaXZCio6lksu+d8sKTWU1eEiEb3cQ7nuZXLYH7leeYFoPtbFV4RicIWp0/YG+RP7rLPCwIDAQAB".decodeBase64()))
 }
 }
 internal val defaultInitialPublicKey: ECDHInitialPublicKey by lazy { ECDHInitialPublicKey(keyStr = "04EBCA94D733E399B2DB96EACDD3F69A8BB0F74224E2B44E3357812211D2E62EFBC91BB553098E25E33A799ADC7F76FEB208DA7C6522CDB0719A305180CC54A82E") }
 internal val defaultInitialPublicKey: ECDHInitialPublicKey by lazy { ECDHInitialPublicKey(keyStr = "04EBCA94D733E399B2DB96EACDD3F69A8BB0F74224E2B44E3357812211D2E62EFBC91BB553098E25E33A799ADC7F76FEB208DA7C6522CDB0719A305180CC54A82E") }
-private val signHead = "3059301306072a8648ce3d020106082a8648ce3d030107034200".chunkedHexToBytes()
+private val signHead = "3059301306072a8648ce3d020106082a8648ce3d030107034200".hexToBytes()
 
 
 internal fun String.adjustToPublicKey(): ECDHPublicKey {
 internal fun String.adjustToPublicKey(): ECDHPublicKey {
-    return this.chunkedHexToBytes().adjustToPublicKey()
+    return this.hexToBytes().adjustToPublicKey()
 }
 }
 
 
 internal fun ByteArray.adjustToPublicKey(): ECDHPublicKey {
 internal fun ByteArray.adjustToPublicKey(): ECDHPublicKey {

Diferenças do arquivo suprimidas por serem muito extensas
+ 1 - 2
mirai-core/src/jvmTest/kotlin/message/data/MessageRefineTest.kt


Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff