WSCompression.swift 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. //////////////////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // WSCompression.swift
  4. //
  5. // Created by Joseph Ross on 7/16/14.
  6. // Copyright © 2017 Joseph Ross & Vluxe. All rights reserved.
  7. //
  8. // Licensed under the Apache License, Version 2.0 (the "License");
  9. // you may not use this file except in compliance with the License.
  10. // You may obtain a copy of the License at
  11. //
  12. // http://www.apache.org/licenses/LICENSE-2.0
  13. //
  14. // Unless required by applicable law or agreed to in writing, software
  15. // distributed under the License is distributed on an "AS IS" BASIS,
  16. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. // See the License for the specific language governing permissions and
  18. // limitations under the License.
  19. //
  20. //////////////////////////////////////////////////////////////////////////////////////////////////
  21. //////////////////////////////////////////////////////////////////////////////////////////////////
  22. //
  23. // Compression is implemented in conformance with RFC 7692 Compression Extensions
  24. // for WebSocket: https://tools.ietf.org/html/rfc7692
  25. //
  26. //////////////////////////////////////////////////////////////////////////////////////////////////
  27. import Foundation
  28. import zlib
  /// permessage-deflate (RFC 7692) implementation of `CompressionHandler`.
  ///
  /// `load(headers:)` configures the compressor/decompressor from the handshake response's
  /// `Sec-WebSocket-Extensions` header. `compress(data:)` and `decompress(data:isFinal:)` return
  /// `nil` when no compressor/decompressor is configured or when zlib reports an error.
  public class WSCompression: CompressionHandler {
      let headerWSExtensionName = "Sec-WebSocket-Extensions"
      var decompressor: Decompressor?
      var compressor: Compressor?
      /// Set when `server_no_context_takeover` was negotiated: the inflater is reset after each message.
      var decompressorTakeOver = false
      /// Set when `client_no_context_takeover` was negotiated: the deflater is reset after each message.
      var compressorTakeOver = false

      public init() {
      }

      /// Applies the extension parameters found in the handshake response `headers`.
      ///
      /// Without a `Sec-WebSocket-Extensions` header (matched case-insensitively), compression is
      /// left disabled. Otherwise both directions default to a 15-bit window unless
      /// `server_max_window_bits` / `client_max_window_bits` say otherwise.
      public func load(headers: [String: String]) {
          // Drop everything from a previous handshake first. Otherwise a reconnect whose response
          // does not negotiate compression would keep compressing with the stale context.
          decompressorTakeOver = false
          compressorTakeOver = false
          compressor = nil
          decompressor = nil

          // HTTP header names are case-insensitive; prefer the exact spelling, then fall back.
          guard let extensionHeader = headers[headerWSExtensionName]
                  ?? headers.first(where: { $0.key.caseInsensitiveCompare(headerWSExtensionName) == .orderedSame })?.value
          else { return }

          // assume defaults unless the headers say otherwise
          compressor = Compressor(windowBits: 15)
          decompressor = Decompressor(windowBits: 15)

          for rawParameter in extensionHeader.components(separatedBy: ";") {
              let (name, value) = WSCompression.splitParameter(rawParameter)
              switch name {
              case "server_max_window_bits":
                  if let value = value, let bits = Int(value) {
                      decompressor = Decompressor(windowBits: bits)
                  }
              case "client_max_window_bits":
                  if let value = value, let bits = Int(value) {
                      compressor = Compressor(windowBits: bits)
                  }
              case "client_no_context_takeover" where value == nil:
                  compressorTakeOver = true
              case "server_no_context_takeover" where value == nil:
                  decompressorTakeOver = true
              default:
                  // e.g. the leading "permessage-deflate" token, or unknown parameters
                  break
              }
          }
      }

      /// Splits one extension parameter (`name` or `name=value`) into a trimmed name and value.
      ///
      /// Surrounding double quotes are removed from the value because RFC 6455 allows
      /// `quoted-string` parameter values (e.g. `client_max_window_bits="10"`).
      private static func splitParameter(_ raw: String) -> (name: String, value: String?) {
          let trimmed = raw.trimmingCharacters(in: .whitespaces)
          guard let equals = trimmed.firstIndex(of: "=") else {
              return (trimmed, nil)
          }
          let name = trimmed[..<equals].trimmingCharacters(in: .whitespaces)
          var value = trimmed[trimmed.index(after: equals)...].trimmingCharacters(in: .whitespaces)
          if value.count >= 2, value.hasPrefix("\""), value.hasSuffix("\"") {
              value = String(value.dropFirst().dropLast())
          }
          return (name, value)
      }

      /// Inflates one frame payload; `isFinal` marks the last frame of a message.
      /// - Returns: The inflated bytes, or `nil` if decompression is not configured or fails.
      public func decompress(data: Data, isFinal: Bool) -> Data? {
          guard let decompressor = decompressor else { return nil }
          do {
              let decompressedData = try decompressor.decompress(data, finish: isFinal)
              // no_context_takeover resets the window per *message*. Resetting after a non-final
              // fragment would throw away inflate state that the next fragment depends on.
              if isFinal && decompressorTakeOver {
                  try decompressor.reset()
              }
              return decompressedData
          } catch {
              // Errors are surfaced to the caller only as a nil result.
              return nil
          }
      }

      /// Deflates one complete message payload.
      /// - Returns: The compressed bytes, or `nil` if compression is not configured or fails.
      public func compress(data: Data) -> Data? {
          guard let compressor = compressor else { return nil }
          do {
              let compressedData = try compressor.compress(data)
              // Each call compresses a whole message, so resetting here is a per-message reset.
              if compressorTakeOver {
                  try compressor.reset()
              }
              return compressedData
          } catch {
              // Errors are surfaced to the caller only as a nil result.
              return nil
          }
      }
  }
  /// Inflates raw-DEFLATE payloads (no zlib header/trailer, selected by a negative window size).
  ///
  /// A single zlib stream is kept alive across calls, so the LZ77 window carries over from one
  /// call to the next until `reset()` is invoked.
  class Decompressor {
      private var strm = z_stream()
      /// Scratch output buffer; each `inflate` call writes into it and the bytes are copied out.
      private var buffer = [UInt8](repeating: 0, count: 0x2000)
      private var inflateInitialized = false
      private let windowBits: Int

      /// Creates an inflater for a `2^windowBits`-byte window.
      /// Returns `nil` when `inflateInit2_` rejects the configuration.
      init?(windowBits: Int) {
          self.windowBits = windowBits
          guard initInflate() else { return nil }
      }

      /// Initializes `strm`; the negated `windowBits` requests raw DEFLATE.
      private func initInflate() -> Bool {
          if Z_OK == inflateInit2_(&strm, -CInt(windowBits),
                                   ZLIB_VERSION, CInt(MemoryLayout<z_stream>.size)) {
              inflateInitialized = true
              return true
          }
          return false
      }

      /// Discards the sliding window so the next payload is inflated without prior context.
      /// - Throws: `WSError` (`.compressionError`) if the stream cannot be re-initialized.
      func reset() throws {
          // inflateReset keeps the existing allocation instead of freeing and re-allocating the
          // stream for every message; fall back to a full teardown + init if it fails.
          if inflateInitialized && inflateReset(&strm) == Z_OK {
              return
          }
          teardownInflate()
          guard initInflate() else { throw WSError(type: .compressionError, message: "Error for decompressor on reset", code: 0) }
      }

      /// Inflates `data`. When `finish` is true, the 0x00 0x00 0xFF 0xFF tail that senders strip
      /// (RFC 7692 §7.2.2) is fed afterwards so the message's output is fully flushed.
      func decompress(_ data: Data, finish: Bool) throws -> Data {
          return try data.withUnsafeBytes { (raw: UnsafeRawBufferPointer) -> Data in
              guard let base = raw.baseAddress?.assumingMemoryBound(to: UInt8.self) else {
                  // Empty Data has no base address; give zlib a valid pointer with zero length.
                  var placeholder: UInt8 = 0
                  return try decompress(bytes: &placeholder, count: 0, finish: finish)
              }
              return try decompress(bytes: base, count: raw.count, finish: finish)
          }
      }

      /// Pointer-based variant of `decompress(_:finish:)`; `bytes` must hold `count` readable bytes.
      func decompress(bytes: UnsafePointer<UInt8>, count: Int, finish: Bool) throws -> Data {
          var decompressed = Data()
          try decompress(bytes: bytes, count: count, out: &decompressed)
          if finish {
              let tail: [UInt8] = [0x00, 0x00, 0xFF, 0xFF]
              try decompress(bytes: tail, count: tail.count, out: &decompressed)
          }
          return decompressed
      }

      /// Runs `inflate` over the input until it stops filling the scratch buffer, appending
      /// everything produced to `out`.
      /// - Throws: `WSError` (`.compressionError`) on any result other than a clean stop.
      private func decompress(bytes: UnsafePointer<UInt8>, count: Int, out: inout Data) throws {
          var res: CInt = 0
          strm.next_in = UnsafeMutablePointer<UInt8>(mutating: bytes)
          strm.avail_in = CUnsignedInt(count)
          repeat {
              buffer.withUnsafeMutableBytes { bufferPtr in
                  strm.next_out = bufferPtr.bindMemory(to: UInt8.self).baseAddress
                  strm.avail_out = CUnsignedInt(bufferPtr.count)
                  res = inflate(&strm, 0)
              }
              out.append(buffer, count: buffer.count - Int(strm.avail_out))
          } while res == Z_OK && strm.avail_out == 0
          // Accepted endings: Z_OK with room left in the buffer (input exhausted), or Z_BUF_ERROR
          // from an extra pass that produced nothing after the buffer was filled exactly.
          // NOTE(review): Z_STREAM_END (a sender finishing with a BFINAL=1 block, RFC 7692
          // §7.2.3.5) is rejected here, as in the original implementation.
          guard (res == Z_OK && strm.avail_out > 0)
              || (res == Z_BUF_ERROR && Int(strm.avail_out) == buffer.count)
          else {
              throw WSError(type: .compressionError, message: "Error on decompressing", code: 0)
          }
      }

      /// Releases zlib's state; the flag is cleared only if `inflateEnd` succeeds.
      private func teardownInflate() {
          if inflateInitialized, Z_OK == inflateEnd(&strm) {
              inflateInitialized = false
          }
      }

      deinit {
          teardownInflate()
      }
  }
  /// Deflates message payloads as raw DEFLATE (negative window size) for permessage-deflate.
  ///
  /// A single zlib stream is kept alive across calls, so the LZ77 window carries over from one
  /// message to the next until `reset()` is invoked.
  class Compressor {
      private var strm = z_stream()
      /// Scratch output buffer; each `deflate` call writes into it and the bytes are copied out.
      private var buffer = [UInt8](repeating: 0, count: 0x2000)
      private var deflateInitialized = false
      private let windowBits: Int

      /// Creates a deflater for a `2^windowBits`-byte window.
      /// Returns `nil` when `deflateInit2_` rejects the configuration.
      init?(windowBits: Int) {
          self.windowBits = windowBits
          guard initDeflate() else { return nil }
      }

      /// Initializes `strm` with default level/strategy and memLevel 8; the negated
      /// `windowBits` requests raw DEFLATE.
      private func initDeflate() -> Bool {
          if Z_OK == deflateInit2_(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED,
                                   -CInt(windowBits), 8, Z_DEFAULT_STRATEGY,
                                   ZLIB_VERSION, CInt(MemoryLayout<z_stream>.size)) {
              deflateInitialized = true
              return true
          }
          return false
      }

      /// Discards the sliding window so the next message is compressed without prior context.
      /// - Throws: `WSError` (`.compressionError`) if the stream cannot be re-initialized.
      func reset() throws {
          // deflateReset keeps the existing allocation instead of freeing and re-allocating the
          // stream for every message; fall back to a full teardown + init if it fails.
          if deflateInitialized && deflateReset(&strm) == Z_OK {
              return
          }
          teardownDeflate()
          guard initDeflate() else { throw WSError(type: .compressionError, message: "Error for compressor on reset", code: 0) }
      }

      /// Compresses `data` with `Z_SYNC_FLUSH` and strips the trailing 0x00 0x00 0xFF 0xFF,
      /// as RFC 7692 §7.2.1 requires. Empty input is returned unchanged.
      /// - Throws: `WSError` (`.compressionError`) if zlib reports an error.
      func compress(_ data: Data) throws -> Data {
          guard !data.isEmpty else {
              // For example, PONG has no content
              return data
          }
          var compressed = Data()
          var res: CInt = 0
          data.withUnsafeBytes { (raw: UnsafeRawBufferPointer) -> Void in
              strm.next_in = UnsafeMutablePointer<UInt8>(mutating: raw.baseAddress?.assumingMemoryBound(to: UInt8.self))
              strm.avail_in = CUnsignedInt(raw.count)
              // Keep calling deflate while it fills the whole scratch buffer: more output is pending.
              repeat {
                  buffer.withUnsafeMutableBytes { bufferPtr in
                      strm.next_out = bufferPtr.bindMemory(to: UInt8.self).baseAddress
                      strm.avail_out = CUnsignedInt(bufferPtr.count)
                      res = deflate(&strm, Z_SYNC_FLUSH)
                  }
                  compressed.append(buffer, count: buffer.count - Int(strm.avail_out))
              } while res == Z_OK && strm.avail_out == 0
          }
          let finishedCleanly = (res == Z_OK && strm.avail_out > 0)
              || (res == Z_BUF_ERROR && Int(strm.avail_out) == buffer.count)
          // The count check keeps removeLast(4) from trapping if the sync-flush tail is missing.
          guard finishedCleanly, compressed.count >= 4 else {
              throw WSError(type: .compressionError, message: "Error on compressing", code: 0)
          }
          compressed.removeLast(4)
          return compressed
      }

      /// Releases zlib's state; the flag is cleared only if `deflateEnd` succeeds.
      private func teardownDeflate() {
          if deflateInitialized, Z_OK == deflateEnd(&strm) {
              deflateInitialized = false
          }
      }

      deinit {
          teardownDeflate()
      }
  }