Browse Source

shiftLeft as Generic - tryout

Marcin Krzyżanowski 11 years ago
parent
commit
7704ab0f34

+ 81 - 2
CryptoSwift/Generics.swift

@@ -11,9 +11,10 @@ import Foundation
 /** Protocol and extensions for integerFromBitsArray. Bit hakish for me, but I can't do it in any other way */
 protocol Initiable  {
     init(_ v: Int)
+    init(_ v: UInt)
 }
 
-extension UInt:Initiable {}
+extension Int:Initiable {}
 extension UInt:Initiable {}
 extension UInt8:Initiable {}
 extension UInt16:Initiable {}
@@ -63,4 +64,82 @@ func arrayOfBytes<T>(value:T, totalBytes:Int) -> [Byte] {
     }
     
     return bytes
-}
+}
+
+// MARK: - shiftLeft
+
+// helper to be able tomake shift operation on T
+func <<<T:SignedIntegerType>(lhs: T, rhs: Int) -> Int {
+    let a = lhs as Int
+    let b = rhs
+    return a << b
+}
+
+func <<<T:UnsignedIntegerType>(lhs: T, rhs: Int) -> UInt {
+    let a = lhs as UInt
+    let b = rhs
+    return a << b
+}
+
+// Generic function itself
+// FIXME: this generic function is not as generic as I would. It crashes for smaller types
+func shiftLeft<T: SignedIntegerType where T: Initiable>(value: T, count: Int) -> T {
+    if (value == 0) {
+        return 0;
+    }
+    
+    var bitsCount = (sizeofValue(value) * 8)
+    var shiftCount = Int(Swift.min(count, bitsCount - 1))
+    
+    var shiftedValue:T = 0;
+    for bitIdx in 0..<bitsCount {
+        var bit = T.from(IntMax(1 << bitIdx))
+        if ((value & bit) == bit) {
+            shiftedValue = shiftedValue | T(bit << shiftCount)
+        }
+    }
+    
+    if (shiftedValue != 0 && count >= bitsCount) {
+        // clear last bit that couldn't be shifted out of range
+        shiftedValue = shiftedValue & T(~(1 << (bitsCount - 1)))
+    }
+    return shiftedValue
+}
+
+// for any f*** other Integer type - this part is so non-Generic
+func shiftLeft(value: UInt, count: Int) -> UInt {
+    return UInt(shiftLeft(Int(value), count))
+}
+
+func shiftLeft(value: UInt8, count: Int) -> UInt8 {
+    return UInt8(shiftLeft(UInt(value), count))
+}
+
+func shiftLeft(value: UInt16, count: Int) -> UInt16 {
+    return UInt16(shiftLeft(UInt(value), count))
+}
+
+func shiftLeft(value: UInt32, count: Int) -> UInt32 {
+    return UInt32(shiftLeft(UInt(value), count))
+}
+
+func shiftLeft(value: UInt64, count: Int) -> UInt64 {
+    return UInt64(shiftLeft(UInt(value), count))
+}
+
+func shiftLeft(value: Int8, count: Int) -> Int8 {
+    return Int8(shiftLeft(Int(value), count))
+}
+
+func shiftLeft(value: Int16, count: Int) -> Int16 {
+    return Int16(shiftLeft(Int(value), count))
+}
+
+func shiftLeft(value: Int32, count: Int) -> Int32 {
+    return Int32(shiftLeft(Int(value), count))
+}
+
+func shiftLeft(value: Int64, count: Int) -> Int64 {
+    return Int64(shiftLeft(Int(value), count))
+}
+

+ 3 - 18
CryptoSwift/IntExtension.swift

@@ -47,22 +47,7 @@ extension Int {
     
     /** Shift bits to the left. All bits are shifted (including sign bit) */
     private mutating func shiftLeft(count: Int) -> Int {
-        if (self == 0) {
-            return self;
-        }
-        
-        var bitsCount = sizeofValue(self) * 8
-        var shiftCount = Swift.min(count, bitsCount - 1)
-        var shiftedValue:Int = 0;
-        
-        for bitIdx in 0..<bitsCount {
-            // if bit is set then copy to result and shift left 1
-            var bit = 1 << bitIdx
-            if ((self & bit) == bit) {
-                shiftedValue = shiftedValue | (bit << shiftCount)
-            }
-        }
-        self = shiftedValue
+        self = CryptoSwift.shiftLeft(self, count)
         return self
     }
     
@@ -97,12 +82,12 @@ extension Int {
 // Left operator
 
 /** shift left and assign with bits truncation */
-func &<<= (inout lhs: Int, rhs: Int) {
+public func &<<= (inout lhs: Int, rhs: Int) {
     lhs.shiftLeft(rhs)
 }
 
 /** shift left with bits truncation */
-func &<< (lhs: Int, rhs: Int) -> Int {
+public func &<< (lhs: Int, rhs: Int) -> Int {
     var l = lhs;
     l.shiftLeft(rhs)
     return l

+ 1 - 1
CryptoSwift/Poly1305.swift

@@ -293,7 +293,7 @@ public class Poly1305 {
                 }
                 for j in (i+1)..<17 {
                     var v:UInt32 = UInt32(UInt16(context.h[j])) * UInt32(context.r[i + 17 - j])  // unsigned long v = (unsigned short)st->h[j] * st->r[i + 17 - j];
-                    v = ((v &<< 8) &+ (v &<< 6))
+                    v = ((v << 8) &+ (v << 6))
                     u = u &+ v
                 }
                 hr[i] = u

+ 8 - 2
CryptoSwift/UInt32Extension.swift

@@ -44,6 +44,12 @@ extension UInt32 {
                 shiftedValue = shiftedValue | (bit << shiftCount)
             }
         }
+        
+        if (shiftedValue != 0 && count >= bitsCount) {
+            // clear last bit that couldn't be shifted out of range
+            shiftedValue = shiftedValue & (~(1 << (bitsCount - 1)))
+        }
+
         self = shiftedValue
         return self
     }
@@ -78,12 +84,12 @@ extension UInt32 {
 }
 
 /** shift left and assign with bits truncation */
-func &<<= (inout lhs: UInt32, rhs: UInt32) {
+public func &<<= (inout lhs: UInt32, rhs: UInt32) {
     lhs.shiftLeft(rhs)
 }
 
 /** shift left with bits truncation */
-func &<< (lhs: UInt32, rhs: UInt32) -> UInt32 {
+public func &<< (lhs: UInt32, rhs: UInt32) -> UInt32 {
     var l = lhs;
     l.shiftLeft(rhs)
     return l

+ 24 - 0
CryptoSwiftTests/ExtensionsTest.swift

@@ -8,6 +8,7 @@
 
 import UIKit
 import XCTest
+import CryptoSwift
 
 class ExtensionsTest: XCTestCase {
 
@@ -45,4 +46,27 @@ class ExtensionsTest: XCTestCase {
         XCTAssertTrue(bytes.count == 16, "Invalid return type \(bytes.count)")
         XCTAssertTrue(bytes[14] == 4, "Invalid return type \(bytes.count)")
     }
+    
+    func testShiftLeft() {
+        // Unsigned
+        var i:UInt32 = 1
+        XCTAssert(i &<< 1 == 2, "shift left failed")
+        XCTAssert(i &<< 8 == 256, "shift left failed")
+        XCTAssert(i &<< 32 == 2147483648, "shift left failed")
+        XCTAssert(i &<< 33 == 2147483648, "shift left failed")
+
+        // Signed
+        var ii:Int = 21
+        XCTAssert(ii &<< 1 == ii << 1, "shift left failed")
+        XCTAssert(ii &<< 8 == ii << 8, "shift left failed")
+        XCTAssert(ii &<< ((sizeofValue(ii) * 8) - 1) == ii << ((sizeofValue(ii) * 8) - 1), "shift left failed")
+        XCTAssert(ii &<< ((sizeofValue(ii) * 8)) == 0, "shift left failed")
+        
+        var iii:UInt32 = 21
+        XCTAssert(iii &<< 1 == iii << 1, "shift left failed")
+        XCTAssert(iii &<< 8 == iii << 8, "shift left failed")
+        XCTAssert(iii &<< ((sizeofValue(iii) * 8) - 1) == iii << ((sizeofValue(iii) * 8) - 1), "shift left failed")
+        XCTAssert((iii &<< 32) == 0, "shift left failed")
+
+    }
 }