Przeglądaj źródła

Improve AES Updatable implementation. Don't add padding if it's not required for the block mode.

Marcin Krzyżanowski 8 lat temu
rodzic
commit
bea5cc1bce

+ 13 - 5
Sources/CryptoSwift/AES.swift

@@ -426,7 +426,7 @@ extension AES {
         mutating public func update<T: Collection>(withBytes bytes: T, isLast: Bool = false) throws -> Array<UInt8> where T.Iterator.Element == UInt8 {
             self.accumulated += bytes
 
-            if isLast {
+            if isLast && self.paddingRequired {
                 self.accumulated = padding.add(to: self.accumulated, blockSize: AES.blockSize)
             }
 
@@ -542,14 +542,16 @@ extension AES: Cryptors {
 extension AES: Cipher {
 
     public func encrypt<C: Collection>(_ bytes: C) throws -> Array<UInt8> where C.Iterator.Element == UInt8, C.IndexDistance == Int, C.Index == Int {
-        let chunks = Array(bytes).chunks(size: AES.blockSize)
+        let chunks = bytes.batched(by: AES.blockSize)
 
         var oneTimeCryptor = self.makeEncryptor()
         var out = Array<UInt8>()
         out.reserveCapacity(bytes.count)
         for idx in chunks.indices {
-            out += try oneTimeCryptor.update(withBytes: chunks[idx], isLast: idx == chunks.endIndex.advanced(by: -1))
+            out += try oneTimeCryptor.update(withBytes: chunks[idx] as! ArraySlice<UInt8>, isLast: false)
         }
+        // Padding may be added at the very end
+        out += try oneTimeCryptor.finish()
 
         if blockMode.options.contains(.PaddingRequired) && (out.count % AES.blockSize != 0) {
             throw Error.dataPaddingRequired
@@ -564,11 +566,17 @@ extension AES: Cipher {
         }
 
         var oneTimeCryptor = self.makeDecryptor()
-        let chunks = Array(bytes).chunks(size: AES.blockSize)
+        let chunks = bytes.batched(by: AES.blockSize)
         var out = Array<UInt8>()
         out.reserveCapacity(bytes.count)
+
+        var lastIdx = chunks.startIndex
+        chunks.indices.formIndex(&lastIdx, offsetBy: chunks.count - 1)
+
+        // To properly remove padding, `isLast` has to be known when called with the last chunk of ciphertext
+        // Last chunk of ciphertext may contains padded data so next call to update(..) won't be able to remove it
         for idx in chunks.indices {
-            out += try oneTimeCryptor.update(withBytes: chunks[idx], isLast: idx == chunks.endIndex.advanced(by: -1))
+            out += try oneTimeCryptor.update(withBytes: chunks[idx] as! ArraySlice<UInt8>, isLast: idx == lastIdx)
         }
         return out
     }

+ 1 - 1
Sources/CryptoSwift/BlockCipher.swift

@@ -6,6 +6,6 @@
 //  Copyright © 2016 Marcin Krzyzanowski. All rights reserved.
 //
 
-protocol BlockCipher: class {
+protocol BlockCipher: Cipher {
     static var blockSize: Int { get }
 }

+ 2 - 2
Sources/CryptoSwift/UInt64+Extension.swift

@@ -10,12 +10,12 @@
 extension UInt64 {
 
     @_specialize(ArraySlice<UInt8>)
-    init<T: Collection>(bytes: T) where T.Iterator.Element == UInt8, T.Index == Int {
+    init<T: Collection>(bytes: T) where T.Iterator.Element == UInt8, T.IndexDistance == Int , T.Index == Int, T.SubSequence: Collection {
         self = UInt64(bytes: bytes, fromIndex: bytes.startIndex)
     }
 
     @_specialize(ArraySlice<UInt8>)
-    init<T: Collection>(bytes: T, fromIndex index: T.Index) where T.Iterator.Element == UInt8, T.Index == Int {
+    init<T: Collection>(bytes: T, fromIndex index: T.Index) where T.Iterator.Element == UInt8, T.IndexDistance == Int, T.Index == Int, T.SubSequence: Collection {
         let val0 = UInt64(bytes[index.advanced(by: 0)]) << 56
         let val1 = UInt64(bytes[index.advanced(by: 1)]) << 48
         let val2 = UInt64(bytes[index.advanced(by: 2)]) << 40

+ 11 - 10
Sources/CryptoSwift/Updatable.swift

@@ -11,18 +11,19 @@
 public protocol Updatable {
     /// Update given bytes in chunks.
     ///
-    /// - parameter bytes: Bytes to process
-    /// - parameter isLast: (Optional) Given chunk is the last one. No more updates after this call.
+    /// - parameter bytes: Bytes to process.
+    /// - parameter isLast: Indicate if given chunk is the last one. No more updates after this call.
     /// - returns: Processed data or empty array.
     mutating func update<T: Collection>(withBytes bytes: T, isLast: Bool) throws -> Array<UInt8> where T.Iterator.Element == UInt8
 
     /// Update given bytes in chunks.
     ///
-    /// - parameter bytes: Bytes to process
-    /// - parameter isLast: (Optional) Given chunk is the last one. No more updates after this call.
-    /// - parameter output: Resulting data
-    /// - returns: Processed data or empty array.
-    mutating func update<T: Collection>(withBytes bytes: T, isLast: Bool, output: (Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8
+    /// - Parameters:
+    ///   - bytes: Bytes to process.
+    ///   - isLast: Indicate if given chunk is the last one. No more updates after this call.
+    ///   - output: Resulting bytes callback.
+    /// - Returns: Processed data or empty array.
+    mutating func update<T: Collection>(withBytes bytes: T, isLast: Bool, output: (_ bytes: Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8
 
     /// Finish updates. This may apply padding.
     /// - parameter bytes: Bytes to process
@@ -33,12 +34,12 @@ public protocol Updatable {
     /// - parameter bytes: Bytes to process
     /// - parameter output: Resulting data
     /// - returns: Processed data.
-    mutating func finish<T: Collection>(withBytes bytes: T, output: (Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8
+    mutating func finish<T: Collection>(withBytes bytes: T, output: (_ bytes: Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8
 }
 
 extension Updatable {
 
-    mutating public func update<T: Collection>(withBytes bytes: T, isLast: Bool = false, output: (Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8 {
+    mutating public func update<T: Collection>(withBytes bytes: T, isLast: Bool = false, output: (_ bytes: Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8 {
         let processed = try self.update(withBytes: bytes, isLast: isLast)
         if (!processed.isEmpty) {
             output(processed)
@@ -53,7 +54,7 @@ extension Updatable {
         return try self.update(withBytes: [], isLast: true)
     }
 
-    mutating public func finish<T: Collection>(withBytes bytes: T, output: (Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8 {
+    mutating public func finish<T: Collection>(withBytes bytes: T, output: (_ bytes: Array<UInt8>) -> Void) throws where T.Iterator.Element == UInt8 {
         let processed = try self.update(withBytes: bytes, isLast: true)
         if (!processed.isEmpty) {
             output(processed)