Эх сурвалжийг харах

Add testAESCCMTestCase2. Fix padding. Fix counter

Marcin Krzyzanowski 6 жил өмнө
parent
commit
acee36d5bf

+ 25 - 7
Sources/CryptoSwift/BlockMode/CCM.swift

@@ -96,7 +96,6 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
         for block_i in encodedAAD.batched(by: 16) {
         for block_i in encodedAAD.batched(by: 16) {
             let y_i = cipherOperation(xor(block_i, last_y))!.slice
             let y_i = cipherOperation(xor(block_i, last_y))!.slice
             last_y = y_i
             last_y = y_i
-            counter += 1
         }
         }
     }
     }
 
 
@@ -120,13 +119,13 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
             // Need a full block here to update keystream and do CBC
             // Need a full block here to update keystream and do CBC
             if keystream.isEmpty || keystreamPosIdx == blockSize {
             if keystream.isEmpty || keystreamPosIdx == blockSize {
                 // y[i], where i is the counter. Can encrypt 1 block at a time
                 // y[i], where i is the counter. Can encrypt 1 block at a time
+                counter += 1
                 guard let S = try? S(i: counter) else { return Array(plaintext) }
                 guard let S = try? S(i: counter) else { return Array(plaintext) }
 
 
-                let plaintextP = ZeroPadding().add(to: Array(plaintext), blockSize: 16)
+                let plaintextP = addPadding(Array(plaintext), blockSize: 16)
                 guard let y = cipherOperation(xor(last_y, plaintextP)) else { return Array(plaintext) }
                 guard let y = cipherOperation(xor(last_y, plaintextP)) else { return Array(plaintext) }
                 last_y = y.slice
                 last_y = y.slice
 
 
-                counter += 1
                 keystream = S
                 keystream = S
                 keystreamPosIdx = 0
                 keystreamPosIdx = 0
             }
             }
@@ -238,15 +237,34 @@ private func format(aad: [UInt8]) -> [UInt8] {
     switch Double(a) {
     switch Double(a) {
     case 0..<65280: // 2^16-2^8
     case 0..<65280: // 2^16-2^8
         // [a]16
         // [a]16
-        return ZeroPadding().add(to: a.bytes(totalBytes: 2) + aad, blockSize: 16)
+        return addPadding(a.bytes(totalBytes: 2) + aad, blockSize: 16)
     case 65280..<4_294_967_296: // 2^32
     case 65280..<4_294_967_296: // 2^32
         // [a]32
         // [a]32
-        return ZeroPadding().add(to: [0xFF, 0xFE] + a.bytes(totalBytes: 4) + aad, blockSize: 16)
+        return addPadding([0xFF, 0xFE] + a.bytes(totalBytes: 4) + aad, blockSize: 16)
     case 4_294_967_296..<pow(2,64): // 2^64
     case 4_294_967_296..<pow(2,64): // 2^64
         // [a]64
         // [a]64
-        return ZeroPadding().add(to: [0xFF, 0xFF] + a.bytes(totalBytes: 8) + aad, blockSize: 16)
+        return addPadding([0xFF, 0xFF] + a.bytes(totalBytes: 8) + aad, blockSize: 16)
     default:
     default:
         // Reserved
         // Reserved
-        return ZeroPadding().add(to: aad, blockSize: 16)
+        return addPadding(aad, blockSize: 16)
+    }
+}
+
+// If data is not a multiple of block size bytes long then the remainder is zero padded
+// Note: It's similar to ZeroPadding, but it's not the same.
+private func addPadding(_ bytes: Array<UInt8>, blockSize: Int) -> Array<UInt8> {
+    if bytes.isEmpty {
+        return Array<UInt8>(repeating: 0, count: blockSize)
+    }
+
+    let remainder = bytes.count % blockSize
+    if remainder == 0 {
+        return bytes
+    }
+
+    let paddingCount = blockSize - remainder
+    if paddingCount > 0 {
+        return bytes + Array<UInt8>(repeating: 0, count: paddingCount)
     }
     }
+    return bytes
 }
 }

+ 20 - 7
Tests/Tests/AESTests.swift

@@ -575,17 +575,29 @@ extension AESTests {
     }
     }
 
 
     func testAESCCMTestCase1() {
     func testAESCCMTestCase1() {
-        let key: Array<UInt8> = [0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f];
-        let nonce: Array<UInt8> = [0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]
-        let aad: Array<UInt8> = [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
+        let key: Array<UInt8> =       [0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f];
+        let nonce: Array<UInt8> =     [0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]
+        let aad: Array<UInt8> =       [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
         let plaintext: Array<UInt8> = [0x20, 0x21, 0x22, 0x23]
         let plaintext: Array<UInt8> = [0x20, 0x21, 0x22, 0x23]
-        let expected: Array<UInt8> = [0x71, 0x62, 0x01, 0x5b, 0x4d, 0xac, 0x25, 0x5d]
+        let expected: Array<UInt8> =  [0x71, 0x62, 0x01, 0x5b, 0x4d, 0xac, 0x25, 0x5d]
 
 
         let aes = try! AES(key: key, blockMode: CCM(nonce: nonce, tagLength: 4, messageLength: plaintext.count, additionalAuthenticatedData: aad), padding: .noPadding)
         let aes = try! AES(key: key, blockMode: CCM(nonce: nonce, tagLength: 4, messageLength: plaintext.count, additionalAuthenticatedData: aad), padding: .noPadding)
         let encrypted = try! aes.encrypt(plaintext)
         let encrypted = try! aes.encrypt(plaintext)
         XCTAssertEqual(encrypted, expected, "encryption failed")
         XCTAssertEqual(encrypted, expected, "encryption failed")
-//        let decrypted = try! aes.decrypt(encrypted)
-//        XCTAssertEqual(decrypted, plaintext, "decryption failed")
+        // let decrypted = try! aes.decrypt(encrypted)
+        // XCTAssertEqual(decrypted, plaintext, "decryption failed")
+    }
+
+    func testAESCCMTestCase2() {
+        let key: Array<UInt8> =       [0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f];
+        let nonce: Array<UInt8> =     [0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17]
+        let aad: Array<UInt8> =       [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f]
+        let plaintext: Array<UInt8> = [0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f]
+        let expected: Array<UInt8>  = [0xd2, 0xa1, 0xf0, 0xe0, 0x51, 0xea, 0x5f, 0x62, 0x08, 0x1a, 0x77, 0x92, 0x07, 0x3d, 0x59, 0x3d, 0x1f, 0xc6, 0x4f, 0xbf, 0xac, 0xcd]
+
+        let aes = try! AES(key: key, blockMode: CCM(nonce: nonce, tagLength: 6, messageLength: plaintext.count, additionalAuthenticatedData: aad), padding: .noPadding)
+        let encrypted = try! aes.encrypt(plaintext)
+        XCTAssertEqual(encrypted, expected, "encryption failed")
     }
     }
 
 
 }
 }
@@ -626,7 +638,8 @@ extension AESTests {
             ("testAESGCMTestCase7", testAESGCMTestCase7),
             ("testAESGCMTestCase7", testAESGCMTestCase7),
             ("testAESGCMTestCaseIrregularCombined1", testAESGCMTestCaseIrregularCombined1),
             ("testAESGCMTestCaseIrregularCombined1", testAESGCMTestCaseIrregularCombined1),
             ("testAESGCMTestCaseIrregularCombined2", testAESGCMTestCaseIrregularCombined2),
             ("testAESGCMTestCaseIrregularCombined2", testAESGCMTestCaseIrregularCombined2),
-            ("testAESCCMTestCase1", testAESCCMTestCase1)
+            ("testAESCCMTestCase1", testAESCCMTestCase1),
+            ("testAESCCMTestCase2", testAESCCMTestCase2)
         ]
         ]
         return tests
         return tests
     }
     }