Bladeren bron

CCM decryption

Marcin Krzyzanowski 6 jaren geleden
bovenliggende
commit
d28b99527b
3 gewijzigde bestanden met toevoegingen van 72 en 19 verwijderingen
  1. 53 12
      Sources/CryptoSwift/BlockMode/CCM.swift
  2. 3 3
      Sources/CryptoSwift/BlockMode/GCM.swift
  3. 16 4
      Tests/Tests/AESTests.swift

+ 53 - 12
Sources/CryptoSwift/BlockMode/CCM.swift

@@ -28,21 +28,35 @@ public struct CCM: StreamMode {
         /// Invalid IV
         /// Invalid IV
         case invalidInitializationVector
         case invalidInitializationVector
         case invalidParameter
         case invalidParameter
+        case fail
     }
     }
 
 
-    public let options: BlockModeOption = [.initializationVectorRequired]
+    public let options: BlockModeOption = [.initializationVectorRequired, .useEncryptToDecrypt]
     private let nonce: Array<UInt8>
     private let nonce: Array<UInt8>
     private let additionalAuthenticatedData: Array<UInt8>?
     private let additionalAuthenticatedData: Array<UInt8>?
     private let tagLength: Int
     private let tagLength: Int
     private let messageLength: Int // total message length. need to know in advance
     private let messageLength: Int // total message length. need to know in advance
 
 
+    // `authenticationTag` nil for encryption, known tag for decryption
+    /// For encryption, the value is set at the end of the encryption.
+    /// For decryption, this is a known Tag to validate against.
+    public var authenticationTag: Array<UInt8>?
+
+    // encrypt
     public init(nonce: Array<UInt8>, tagLength: Int, messageLength: Int, additionalAuthenticatedData: Array<UInt8>? = nil) {
     public init(nonce: Array<UInt8>, tagLength: Int, messageLength: Int, additionalAuthenticatedData: Array<UInt8>? = nil) {
         self.nonce = nonce
         self.nonce = nonce
         self.tagLength = tagLength
         self.tagLength = tagLength
         self.additionalAuthenticatedData = additionalAuthenticatedData
         self.additionalAuthenticatedData = additionalAuthenticatedData
-        self.messageLength = messageLength
+        self.messageLength = messageLength - tagLength
+    }
+
+    // decrypt
+    public init(nonce: Array<UInt8>, tagLength: Int, messageLength: Int, authenticationTag: Array<UInt8>, additionalAuthenticatedData: Array<UInt8>? = nil) {
+        self.init(nonce: nonce, tagLength: tagLength, messageLength: messageLength, additionalAuthenticatedData: additionalAuthenticatedData)
+        self.authenticationTag = authenticationTag
     }
     }
 
 
+
     public func worker(blockSize: Int, cipherOperation: @escaping CipherOperationOnBlock) throws -> CipherModeWorker {
     public func worker(blockSize: Int, cipherOperation: @escaping CipherOperationOnBlock) throws -> CipherModeWorker {
         if nonce.isEmpty {
         if nonce.isEmpty {
             throw Error.invalidInitializationVector
             throw Error.invalidInitializationVector
@@ -74,11 +88,12 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
         case invalidParameter
         case invalidParameter
     }
     }
 
 
-    init(blockSize: Int, nonce: ArraySlice<UInt8>, messageLength: Int,  additionalAuthenticatedData: [UInt8]?, tagLength: Int, cipherOperation: @escaping CipherOperationOnBlock) {
-        self.blockSize = blockSize
+    init(blockSize: Int, nonce: ArraySlice<UInt8>, messageLength: Int,  additionalAuthenticatedData: [UInt8]?, expectedTag: Array<UInt8>? = nil, tagLength: Int, cipherOperation: @escaping CipherOperationOnBlock) {
+        self.blockSize = 16// blockSize
         self.tagLength = tagLength
         self.tagLength = tagLength
         self.additionalBufferSize = tagLength
         self.additionalBufferSize = tagLength
         self.messageLength = messageLength
         self.messageLength = messageLength
+        self.expectedTag = expectedTag
         self.cipherOperation = cipherOperation
         self.cipherOperation = cipherOperation
         self.nonce = Array(nonce)
         self.nonce = Array(nonce)
         self.q = UInt8(15 - nonce.count) // n = 15-q
         self.q = UInt8(15 - nonce.count) // n = 15-q
@@ -130,7 +145,7 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
                 counter += 1
                 counter += 1
                 guard let S = try? S(i: counter) else { return Array(plaintext) }
                 guard let S = try? S(i: counter) else { return Array(plaintext) }
 
 
-                let plaintextP = addPadding(Array(plaintext), blockSize: 16)
+                let plaintextP = addPadding(Array(plaintext), blockSize: blockSize)
                 guard let y = cipherOperation(xor(last_y, plaintextP)) else { return Array(plaintext) }
                 guard let y = cipherOperation(xor(last_y, plaintextP)) else { return Array(plaintext) }
                 last_y = y.slice
                 last_y = y.slice
 
 
@@ -150,21 +165,47 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
         // concatenate T at the end
         // concatenate T at the end
         guard let S0 = try? S(i: 0) else { return ciphertext }
         guard let S0 = try? S(i: 0) else { return ciphertext }
 
 
-        let tag = last_y.prefix(tagLength)
-        return ciphertext + (xor(tag, S0) as ArraySlice<UInt8>)
+        let tag = xor(last_y.prefix(tagLength), S0) as ArraySlice<UInt8>
+        return ciphertext + tag
     }
     }
 
 
-    // TODO
     func decrypt(block ciphertext: ArraySlice<UInt8>) -> Array<UInt8> {
     func decrypt(block ciphertext: ArraySlice<UInt8>) -> Array<UInt8> {
-        guard let plaintext = cipherOperation(ciphertext) else {
-            return Array(ciphertext)
+        var result = Array<UInt8>(reserveCapacity: ciphertext.count)
+
+        var processed = 0
+        while processed < ciphertext.count {
+            // Need a full block here to update keystream and do CBC
+            if keystream.isEmpty || keystreamPosIdx == blockSize {
+                // y[i], where i is the counter. Can encrypt 1 block at a time
+                counter += 1
+                guard let S = try? S(i: counter) else { return Array(ciphertext) }
+                let plaintextP = addPadding(xor(ciphertext, S), blockSize: blockSize)
+                guard let y = cipherOperation(xor(last_y, plaintextP)) else { return Array(ciphertext) }
+                last_y = y.slice
+
+                keystream = S
+                keystreamPosIdx = 0
+            }
+
+            let xored: Array<UInt8> = xor(ciphertext[ciphertext.startIndex.advanced(by: processed)...], keystream[keystreamPosIdx...])
+            keystreamPosIdx += xored.count
+            processed += xored.count
+            result += xored
         }
         }
-        let result: Array<UInt8> = xor(last_y, plaintext)
-        last_y = ciphertext
+        // Shouldn't return plaintext until validate tag.
+        // With incremental update, can't validate tag until all block are processed.
         return result
         return result
     }
     }
 
 
     func finalize(decrypt plaintext: ArraySlice<UInt8>) throws -> ArraySlice<UInt8> {
     func finalize(decrypt plaintext: ArraySlice<UInt8>) throws -> ArraySlice<UInt8> {
+        // concatenate T at the end
+        guard let S0 = try? S(i: 0) else { return plaintext }
+
+        let computedTag = xor(last_y.prefix(tagLength), S0) as Array<UInt8>
+        guard let expectedTag = self.expectedTag, expectedTag == computedTag else {
+            throw CCM.Error.fail
+        }
+
         return plaintext
         return plaintext
     }
     }
 
 

+ 3 - 3
Sources/CryptoSwift/BlockMode/GCM.swift

@@ -204,11 +204,11 @@ final class GCMModeWorker: BlockModeWorker, FinalizingEncryptModeWorker, Finaliz
         let computedTag = Array((ghash ^ eky0).bytes.prefix(GCMModeWorker.tagLength))
         let computedTag = Array((ghash ^ eky0).bytes.prefix(GCMModeWorker.tagLength))
 
 
         // Validate tag
         // Validate tag
-        if let expectedTag = self.expectedTag, computedTag == expectedTag {
-            return plaintext
+        guard let expectedTag = self.expectedTag, computedTag == expectedTag else {
+            throw GCM.Error.fail
         }
         }
 
 
-        throw GCM.Error.fail
+        return plaintext
     }
     }
 
 
     func finalize(decrypt plaintext: ArraySlice<UInt8>) throws -> ArraySlice<UInt8> {
     func finalize(decrypt plaintext: ArraySlice<UInt8>) throws -> ArraySlice<UInt8> {

+ 16 - 4
Tests/Tests/AESTests.swift

@@ -386,7 +386,7 @@ final class AESTests: XCTestCase {
     }
     }
 }
 }
 
 
-// GCM Test Vectors
+// MARK: - GCM
 extension AESTests {
 extension AESTests {
     func testAESGCMTestCase1() {
     func testAESGCMTestCase1() {
         // Test Case 1
         // Test Case 1
@@ -573,7 +573,10 @@ extension AESTests {
         let decrypted = decrypt(encrypted)
         let decrypted = decrypt(encrypted)
         XCTAssertEqual(decrypted, plaintext)
         XCTAssertEqual(decrypted, plaintext)
     }
     }
+}
 
 
+// MARK: - CCM
+extension AESTests {
     func testAESCCMTestCase1() {
     func testAESCCMTestCase1() {
         let key: Array<UInt8> =       [0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f];
         let key: Array<UInt8> =       [0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f];
         let nonce: Array<UInt8> =     [0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]
         let nonce: Array<UInt8> =     [0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]
@@ -584,8 +587,18 @@ extension AESTests {
         let aes = try! AES(key: key, blockMode: CCM(nonce: nonce, tagLength: 4, messageLength: plaintext.count, additionalAuthenticatedData: aad), padding: .noPadding)
         let aes = try! AES(key: key, blockMode: CCM(nonce: nonce, tagLength: 4, messageLength: plaintext.count, additionalAuthenticatedData: aad), padding: .noPadding)
         let encrypted = try! aes.encrypt(plaintext)
         let encrypted = try! aes.encrypt(plaintext)
         XCTAssertEqual(encrypted, expected, "encryption failed")
         XCTAssertEqual(encrypted, expected, "encryption failed")
-        // let decrypted = try! aes.decrypt(encrypted)
-        // XCTAssertEqual(decrypted, plaintext, "decryption failed")
+    }
+
+    func testAESCCMTestCase1Decrypt() {
+        let key: Array<UInt8> =       [0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f];
+        let nonce: Array<UInt8> =     [0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]
+        let aad: Array<UInt8> =       [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]
+        let ciphertext: Array<UInt8> = [0x71, 0x62, 0x01, 0x5b, 0x4d, 0xac, 0x25, 0x5d]
+        let expected: Array<UInt8> = [0x20, 0x21, 0x22, 0x23]
+
+        let aes = try! AES(key: key, blockMode: CCM(nonce: nonce, tagLength: 4, messageLength: ciphertext.count, additionalAuthenticatedData: aad), padding: .noPadding)
+        let decrypted = try! aes.decrypt(ciphertext)
+        XCTAssertEqual(decrypted, expected, "decryption failed")
     }
     }
 
 
     func testAESCCMTestCase2() {
     func testAESCCMTestCase2() {
@@ -611,7 +624,6 @@ extension AESTests {
         let encrypted = try! aes.encrypt(plaintext)
         let encrypted = try! aes.encrypt(plaintext)
         XCTAssertEqual(encrypted, expected, "encryption failed")
         XCTAssertEqual(encrypted, expected, "encryption failed")
     }
     }
-
 }
 }
 
 
 extension AESTests {
 extension AESTests {