浏览代码

Improve type-safety on sending frames and remove extraneous class checks. (#438)

Nikita Lutsenko 9 年之前
父节点
当前提交
f2e6387948
共有 1 个文件被更改,包括 31 次插入38 次删除
  1. 31 38
      SocketRocket/SRWebSocket.m

+ 31 - 38
SocketRocket/SRWebSocket.m

@@ -52,7 +52,8 @@ static size_t SRDefaultBufferSize(void) {
     return size;
 }
 
-typedef enum  {
+typedef NS_ENUM(NSInteger, SROpCode)
+{
     SROpCodeTextFrame = 0x1,
     SROpCodeBinaryFrame = 0x2,
     // 3-7 reserved.
@@ -60,7 +61,7 @@ typedef enum  {
     SROpCodePing = 0x9,
     SROpCodePong = 0xA,
     // B-F reserved.
-} SROpCode;
+};
 
 typedef struct {
     BOOL fin;
@@ -803,7 +804,7 @@ static inline BOOL closeCodeIsValid(int closeCode) {
     [self _pumpWriting];
 }
 
-- (void)_handleFrameWithData:(NSData *)frameData opCode:(NSInteger)opcode;
+- (void)_handleFrameWithData:(NSData *)frameData opCode:(SROpCode)opcode
 {
     //frameData will be copied before passing to handlers
     //otherwise there can be misbehaviours when value at the pointer is changed
@@ -1360,27 +1361,25 @@ static const char CRLFCRLFBytes[] = {'\r', '\n', '\r', '\n'};
 
 static const size_t SRFrameHeaderOverhead = 32;
 
-- (void)_sendFrameWithOpcode:(SROpCode)opcode data:(id)data;
+- (void)_sendFrameWithOpcode:(SROpCode)opCode data:(NSData *)data
 {
     [self assertOnWorkQueue];
 
-    if (nil == data) {
+    if (!data) {
         return;
     }
 
-    NSAssert([data isKindOfClass:[NSData class]] || [data isKindOfClass:[NSString class]], @"NSString or NSData");
-
-    size_t payloadLength = [data isKindOfClass:[NSString class]] ? [(NSString *)data lengthOfBytesUsingEncoding:NSUTF8StringEncoding] : [data length];
+    size_t payloadLength = data.length;
 
-    NSMutableData *frame = [[NSMutableData alloc] initWithLength:payloadLength + SRFrameHeaderOverhead];
-    if (!frame) {
+    NSMutableData *frameData = [[NSMutableData alloc] initWithLength:payloadLength + SRFrameHeaderOverhead];
+    if (!frameData) {
         [self closeWithCode:SRStatusCodeMessageTooBig reason:@"Message too big"];
         return;
     }
-    uint8_t *frame_buffer = (uint8_t *)[frame mutableBytes];
+    uint8_t *frameBuffer = (uint8_t *)frameData.mutableBytes;
 
     // set fin
-    frame_buffer[0] = SRFinMask | opcode;
+    frameBuffer[0] = SRFinMask | opCode;
 
     BOOL useMask = YES;
 #ifdef NOMASK
@@ -1389,67 +1388,61 @@ static const size_t SRFrameHeaderOverhead = 32;
 
     if (useMask) {
         // set the mask and header
-        frame_buffer[1] |= SRMaskMask;
+        frameBuffer[1] |= SRMaskMask;
     }
 
-    size_t frame_buffer_size = 2;
-
-    const uint8_t *unmasked_payload = NULL;
-    if ([data isKindOfClass:[NSData class]]) {
-        unmasked_payload = (uint8_t *)[data bytes];
-    } else if ([data isKindOfClass:[NSString class]]) {
-        unmasked_payload =  (const uint8_t *)[data UTF8String];
-    } else {
-        return;
-    }
+    size_t frameBufferSize = 2;
 
     if (payloadLength < 126) {
-        frame_buffer[1] |= payloadLength;
+        frameBuffer[1] |= payloadLength;
     } else {
         uint64_t declaredPayloadLength = 0;
         size_t declaredPayloadLengthSize = 0;
 
         if (payloadLength <= UINT16_MAX) {
-            frame_buffer[1] |= 126;
+            frameBuffer[1] |= 126;
 
             declaredPayloadLength = CFSwapInt16BigToHost((uint16_t)payloadLength);
             declaredPayloadLengthSize = sizeof(uint16_t);
         } else {
-            frame_buffer[1] |= 127;
+            frameBuffer[1] |= 127;
 
             declaredPayloadLength = CFSwapInt64BigToHost((uint64_t)payloadLength);
             declaredPayloadLengthSize = sizeof(uint64_t);
         }
 
-        memcpy((frame_buffer + frame_buffer_size), &declaredPayloadLength, declaredPayloadLengthSize);
-        frame_buffer_size += declaredPayloadLengthSize;
+        memcpy((frameBuffer + frameBufferSize), &declaredPayloadLength, declaredPayloadLengthSize);
+        frameBufferSize += declaredPayloadLengthSize;
     }
 
+    const uint8_t *unmaskedPayloadBuffer = (uint8_t *)data.bytes;
     if (!useMask) {
         for (size_t i = 0; i < payloadLength; i++) {
-            frame_buffer[frame_buffer_size] = unmasked_payload[i];
-            frame_buffer_size += 1;
+            frameBuffer[frameBufferSize] = unmaskedPayloadBuffer[i];
+            frameBufferSize += 1;
         }
     } else {
-        uint8_t *mask_key = frame_buffer + frame_buffer_size;
-        int result = SecRandomCopyBytes(kSecRandomDefault, sizeof(uint32_t), (uint8_t *)mask_key);
+        uint8_t *maskKey = frameBuffer + frameBufferSize;
+
+        size_t randomBytesSize = sizeof(uint32_t);
+        int result = SecRandomCopyBytes(kSecRandomDefault, randomBytesSize, (uint8_t *)maskKey);
         if (result != 0) {
             //TODO: (nlutsenko) Check if there was an error.
         }
 
-        frame_buffer_size += sizeof(uint32_t);
+        frameBufferSize += randomBytesSize;
 
         // TODO: could probably optimize this with SIMD
         for (size_t i = 0; i < payloadLength; i++) {
-            frame_buffer[frame_buffer_size] = unmasked_payload[i] ^ mask_key[i % sizeof(uint32_t)];
-            frame_buffer_size += 1;
+            frameBuffer[frameBufferSize] = unmaskedPayloadBuffer[i] ^ maskKey[i % randomBytesSize];
+            frameBufferSize += 1;
         }
     }
 
-    assert(frame_buffer_size <= [frame length]);
-    frame.length = frame_buffer_size;
+    assert(frameBufferSize <= frameData.length);
+    frameData.length = frameBufferSize;
 
-    [self _writeData:frame];
+    [self _writeData:frameData];
 }
 
 - (void)stream:(NSStream *)aStream handleEvent:(NSStreamEvent)eventCode