Bläddra i källkod

Fix CCM decryption and validation

Marcin Krzyzanowski 6 år sedan
förälder
incheckning
caeee3a2c3

+ 40 - 22
Sources/CryptoSwift/BlockMode/CCM.swift

@@ -158,7 +158,6 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
                 // y[i], where i is the counter. Can encrypt 1 block at a time
                 counter += 1
                 guard let S = try? S(i: counter) else { return Array(plaintext) }
-
                 let plaintextP = addPadding(Array(plaintext), blockSize: blockSize)
                 guard let y = cipherOperation(xor(last_y, plaintextP)) else { return Array(plaintext) }
                 last_y = y.slice
@@ -183,32 +182,41 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
         return ciphertext + computedTag
     }
 
-    func decrypt(block ciphertext: ArraySlice<UInt8>) -> Array<UInt8> {
-        var result = Array<UInt8>(reserveCapacity: ciphertext.count)
+    // Decryption is stream
+    // CBC is block
+    private var accumulatedPlaintext: [UInt8] = []
 
-        var processed = 0
-        while processed < ciphertext.count {
-            // Need a full block here to update keystream and do CBC
-            if keystream.isEmpty || keystreamPosIdx == blockSize {
-                // y[i], where i is the counter. Can encrypt 1 block at a time
-                counter += 1
-                guard let S = try? S(i: counter) else { return Array(ciphertext) }
-                let plaintextP = addPadding(xor(ciphertext, S), blockSize: blockSize)
-                guard let y = cipherOperation(xor(last_y, plaintextP)) else { return Array(ciphertext) }
-                last_y = y.slice
-
-                keystream = S
-                keystreamPosIdx = 0
+    func decrypt(block ciphertext: ArraySlice<UInt8>) -> Array<UInt8> {
+        var output = Array<UInt8>(reserveCapacity: ciphertext.count)
+
+        do {
+            var currentCounter = counter
+            var processed = 0
+            while processed < ciphertext.count {
+                // Need a full block here to update keystream and do CBC
+                // New keystream for a new block
+                if keystream.isEmpty || keystreamPosIdx == blockSize {
+                    currentCounter += 1
+                    guard let S = try? S(i: currentCounter) else { return Array(ciphertext) }
+                    keystream = S
+                    keystreamPosIdx = 0
+                }
+
+                let xored: Array<UInt8> = xor(ciphertext[ciphertext.startIndex.advanced(by: processed)...], keystream[keystreamPosIdx...]) // plaintext
+                keystreamPosIdx += xored.count
+                processed += xored.count
+                output += xored
+                counter = currentCounter
             }
-
-            let xored: Array<UInt8> = xor(ciphertext[ciphertext.startIndex.advanced(by: processed)...], keystream[keystreamPosIdx...])
-            keystreamPosIdx += xored.count
-            processed += xored.count
-            result += xored
         }
+
+        // Accumulate plaintext for the MAC calculations at the end.
+        // It would be good to process it together though, here.
+        accumulatedPlaintext += output
+
         // Shouldn't return plaintext until validate tag.
         // With incremental update, can't validate tag until all block are processed.
-        return result
+        return output
     }
 
     func finalize(decrypt plaintext: ArraySlice<UInt8>) throws -> ArraySlice<UInt8> {
@@ -231,6 +239,16 @@ class CCMModeWorker: StreamModeWorker, SeekableModeWorker, CounterModeWorker, Fi
     }
 
     func didDecryptLast(bytes plaintext: ArraySlice<UInt8>) throws -> ArraySlice<UInt8> {
+
+        // Calculate Tag, from the last CBC block, for accumulated plaintext.
+        var processed = 0
+        for block in accumulatedPlaintext.batched(by: blockSize) {
+            let blockP = addPadding(Array(block), blockSize: blockSize)
+            guard let y = cipherOperation(xor(last_y, blockP)) else { return plaintext }
+            last_y = y.slice
+            processed += block.count
+        }
+        accumulatedPlaintext.removeFirst(processed)
         return plaintext
     }
 }

+ 3 - 9
Sources/CryptoSwift/StreamDecryptor.swift

@@ -28,24 +28,18 @@ final class StreamDecryptor: Cryptor, Updatable {
 
     // MARK: Updatable
     public func update(withBytes bytes: ArraySlice<UInt8>, isLast: Bool) throws -> Array<UInt8> {
-        // TODO: accumulate `worker.additionalBufferSize`
-        // and pass it to willDecrypt(), most likely it will contains MAC
         accumulated += bytes
 
-        // If a worker (eg CCM) can combine ciphertext + tag
-        // we need to remove tag from the ciphertext.
-        if !isLast && accumulated.count < worker.additionalBufferSize {
-            return []
-        }
+        let toProcess = accumulated.prefix(max(accumulated.count - worker.additionalBufferSize, 0))
 
         if var finalizingWorker = worker as? FinalizingDecryptModeWorker, isLast == true {
             // will truncate suffix if needed
-            accumulated = Array(try finalizingWorker.willDecryptLast(bytes: accumulated.slice))
+            try finalizingWorker.willDecryptLast(bytes: accumulated.slice)
         }
 
         var processedBytesCount = 0
         var plaintext = Array<UInt8>(reserveCapacity: bytes.count + worker.additionalBufferSize)
-        for chunk in accumulated.batched(by: blockSize) {
+        for chunk in toProcess.batched(by: blockSize) {
             plaintext += worker.decrypt(block: chunk)
             processedBytesCount += chunk.count
         }

+ 13 - 3
Tests/Tests/AESCCMTests.swift

@@ -1,6 +1,6 @@
 ////  CryptoSwift
 //
-//  Copyright (C) 2014-__YEAR__ Marcin Krzyżanowski <marcin@krzyzanowskim.com>
+//  Copyright (C) 2014-2018 Marcin Krzyżanowski <marcin@krzyzanowskim.com>
 //  This software is provided 'as-is', without any express or implied warranty.
 //
 //  In no event will the authors be held liable for any damages arising from the use of this software.
@@ -381,8 +381,18 @@ final class AESCCMTests: XCTestCase {
             return true
         }
 
-        for fixture in fixtures {
-            XCTAssertTrue(testEncrypt(fixture: fixture))
+        func testDecrypt(fixture: TestFixture) -> Bool {
+            let aes = try! AES(key: fixture.key, blockMode: CCM(iv: fixture.nonce, tagLength: fixture.tagLength, messageLength: fixture.plaintext.count /*- fixture.tagLength*/, additionalAuthenticatedData: fixture.aad), padding: .noPadding)
+            let plaintext = try! aes.decrypt(fixture.expected)
+            if plaintext != fixture.plaintext {
+                return false
+            }
+            return true
+        }
+
+        for (i,fixture) in fixtures.enumerated() {
+            XCTAssertTrue(testEncrypt(fixture: fixture), "Encryption failed")
+            XCTAssertTrue(testDecrypt(fixture: fixture), "(\(i) - Decryption failed.")
         }
     }