Parcourir la source

Accept string without "=" padding

Norio Nomura il y a 10 ans
Parent
commit
cd5bd20404
2 fichiers modifiés avec 63 ajouts et 16 suppressions
  1. 16 16
      Base32/Base32.swift
  2. 47 0
      Base32Tests/Base32Tests.swift

+ 16 - 16
Base32/Base32.swift

@@ -301,23 +301,10 @@ private func base32decode(string: String, table: [UInt8]) -> [UInt8]? {
             return 0
         }
     }
-    // calc padded bytes
-    func paddedBytes(paddingLength: Int) -> Int {
-        switch paddingLength {
-        case 6: return 4
-        case 4: return 3
-        case 3: return 2
-        case 1: return 1
-        default: /* case 0:*/ return 0
-        }
-    }
     
     // validate string
     let leastPaddingLength = getLeastPaddingLength(string)
-    if length % 8 != 0 {
-        println("string length is invalid.")
-        return nil
-    } else if let index = index_of(string.unicodeScalars, {$0.value > 0xff || table[Int($0.value)] > 31}) {
+    if let index = index_of(string.unicodeScalars, {$0.value > 0xff || table[Int($0.value)] > 31}) {
         // index points padding "=" or invalid character that table does not contain.
         let pos = distance(string.unicodeScalars.startIndex, index)
         // if pos points padding "=", it's valid.
@@ -327,9 +314,22 @@ private func base32decode(string: String, table: [UInt8]) -> [UInt8]? {
         }
     }
     
-    // validated
-    let dataSize = length / 8 * 5 - paddedBytes(leastPaddingLength)
     var remainEncodedLength = length - leastPaddingLength
+    var additionalBytes = 0
+    switch remainEncodedLength % 8 {
+        // valid
+    case 0: break
+    case 2: additionalBytes = 1
+    case 4: additionalBytes = 2
+    case 5: additionalBytes = 3
+    case 7: additionalBytes = 4
+    default:
+        println("string length is invalid.")
+        return nil
+    }
+    
+    // validated
+    let dataSize = remainEncodedLength / 8 * 5 + additionalBytes
     
     // Use UnsafePointer<UInt8>
     return string.nulTerminatedUTF8.withUnsafeBufferPointer {

+ 47 - 0
Base32Tests/Base32Tests.swift

@@ -143,4 +143,51 @@ class Base32Tests: XCTestCase {
         }
     }
     
+    // MARK:
+    
+    func test_DecodeStringAcceptableLengthPattern() {
+        // "=" stripped valid string
+        let strippedVectors = vectors.map {
+            (
+                $0,
+                $1.stringByReplacingOccurrencesOfString("=", withString:""),
+                $2.stringByReplacingOccurrencesOfString("=", withString:"")
+            )
+        }
+        for (expect, test, testHex) in strippedVectors {
+            let data = expect.dataUsingEncoding(NSUTF8StringEncoding, allowLossyConversion: false)
+            let result = base32DecodeToData(test)
+            let resultHex = base32HexDecodeToData(testHex)
+            XCTAssertEqual(result!, data!, "base32Decode for \(test)")
+            XCTAssertEqual(resultHex!, data!, "base32HexDecode for \(testHex)")
+        }
+        
+        // invalid length string with padding
+        let invalidVectorWithPaddings: [(String,String)] = [
+            ("M=======", "C======="),
+            ("MYZ=====", "COZ====="),
+            ("MZXW6Z==", "CPNMUZ=="),
+            ("MZXW6YTBO=======", "CPNMUOJ1E======="),
+        ]
+        for (test, testHex) in invalidVectorWithPaddings {
+            let result = base32DecodeToData(test)
+            let resultHex = base32HexDecodeToData(testHex)
+            XCTAssertNil(result, "base32Decode for \(test)")
+            XCTAssertNil(resultHex, "base32HexDecode for \(test)")
+        }
+        
+        // invalid length string without padding
+        let invalidVectorWithoutPaddings = invalidVectorWithPaddings.map {
+            (
+                $0.stringByReplacingOccurrencesOfString("=", withString:""),
+                $1.stringByReplacingOccurrencesOfString("=", withString:"")
+            )
+        }
+        for (test, testHex) in invalidVectorWithPaddings {
+            let result = base32DecodeToData(test)
+            let resultHex = base32HexDecodeToData(testHex)
+            XCTAssertNil(result, "base32Decode for \(test)")
+            XCTAssertNil(resultHex, "base32HexDecode for \(test)")
+        }
+    }
 }