浏览代码

CTRCounter

Marcin Krzyzanowski 7 年之前
父节点
当前提交
c95b227a9b

+ 2 - 0
Sources/CryptoSwift/BlockMode/BlockMode.swift

@@ -20,3 +20,5 @@ public protocol BlockMode {
     //TODO: doesn't have to be public
     func worker(blockSize: Int, cipherOperation: @escaping CipherOperationOnBlock) throws -> CipherModeWorker
 }
+
+typealias StreamMode = BlockMode

+ 37 - 10
Sources/CryptoSwift/BlockMode/CTR.swift

@@ -16,7 +16,7 @@
 //  Counter (CTR)
 //
 
-public struct CTR: BlockMode {
+public struct CTR: StreamMode {
     public enum Error: Swift.Error {
         /// Invalid IV
         case invalidInitializationVector
@@ -40,23 +40,50 @@ public struct CTR: BlockMode {
     }
 }
 
-struct CTRModeWorker: RandomAccessBlockModeWorker {
+struct CTRModeWorker: RandomAccessCipherModeWorker {
+    typealias Counter = CTRCounter
+
+    class CTRCounter {
+        private let constPrefix: Array<UInt8>
+        private var value: UInt64
+        //TODO: make it an updatable value, computing is too slow
+        var bytes: Array<UInt8> {
+            return constPrefix + value.bytes()
+        }
+
+        init(_ initialValue: Array<UInt8>) {
+            let halfIndex = initialValue.startIndex.advanced(by: initialValue.count / 2)
+            constPrefix = Array(initialValue[initialValue.startIndex..<halfIndex])
+
+            let suffixBytes = Array(initialValue[halfIndex..<initialValue.endIndex])
+            value = UInt64(bytes: suffixBytes)
+        }
+
+        convenience init(nonce: Array<UInt8>, startAt index: Int) {
+            self.init(buildCounterValue(nonce, counter: UInt64(index)))
+        }
+
+        static func +=(lhs: CTRCounter, rhs: Int) {
+            lhs.value += UInt64(rhs)
+        }
+    }
+
+
     let cipherOperation: CipherOperationOnBlock
     let additionalBufferSize: Int = 0
-    private let iv: ArraySlice<UInt8>
-    var counter: UInt = 0
+    var counter: Counter
 
     init(blockSize: Int, iv: ArraySlice<UInt8>, counter: Int, cipherOperation: @escaping CipherOperationOnBlock) {
-        self.iv = iv
-        self.counter = UInt(counter)
+        self.counter = Counter(nonce: Array(iv), startAt: counter)
         self.cipherOperation = cipherOperation
     }
 
     mutating func encrypt(block plaintext: ArraySlice<UInt8>) -> Array<UInt8> {
-        let nonce = buildNonce(iv, counter: UInt64(counter))
-        defer { counter += 1 }
+        defer {
+            counter += 1
+        }
 
-        guard let ciphertext = cipherOperation(nonce.slice) else {
+        guard let ciphertext = cipherOperation(counter.bytes.slice) else {
             return Array(plaintext)
         }
 
@@ -68,7 +95,7 @@ struct CTRModeWorker: RandomAccessBlockModeWorker {
     }
 }
 
-private func buildNonce(_ iv: ArraySlice<UInt8>, counter: UInt64) -> Array<UInt8> {
+private func buildCounterValue(_ iv: Array<UInt8>, counter: UInt64) -> Array<UInt8> {
     let noncePartLen = iv.count / 2
     let noncePrefix = iv[iv.startIndex..<iv.startIndex.advanced(by: noncePartLen)]
     let nonceSuffix = iv[iv.startIndex.advanced(by: noncePartLen)..<iv.startIndex.advanced(by: iv.count)]

+ 3 - 2
Sources/CryptoSwift/BlockMode/CipherModeWorker.swift

@@ -28,8 +28,9 @@ public protocol BlockModeWorker: CipherModeWorker {
     var blockSize: Int { get }
 }
 
-protocol RandomAccessBlockModeWorker: CipherModeWorker {
-    var counter: UInt { set get }
+protocol RandomAccessCipherModeWorker: CipherModeWorker {
+     associatedtype Counter
+     var counter: Counter { set get }
 }
 
 // TODO: remove and merge with BlockModeWorker

+ 1 - 1
Tests/Tests/AESTests.swift

@@ -120,7 +120,7 @@ final class AESTests: XCTestCase {
             ciphertext += try encryptor.finish()
             XCTAssertEqual(try aes.encrypt(plaintext.bytes), ciphertext, "encryption failed")
         } catch {
-            XCTAssert(false, "\(error)")
+            XCTFail("\(error)")
         }
     }