WSEngine.swift 8.8 KB


  1. //////////////////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // WSEngine.swift
  4. // Starscream
  5. //
  6. // Created by Dalton Cherry on 6/15/19
  7. // Copyright © 2019 Vluxe. All rights reserved.
  8. //
  9. // Licensed under the Apache License, Version 2.0 (the "License");
  10. // you may not use this file except in compliance with the License.
  11. // You may obtain a copy of the License at
  12. //
  13. // http://www.apache.org/licenses/LICENSE-2.0
  14. //
  15. // Unless required by applicable law or agreed to in writing, software
  16. // distributed under the License is distributed on an "AS IS" BASIS,
  17. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. // See the License for the specific language governing permissions and
  19. // limitations under the License.
  20. //
  21. //////////////////////////////////////////////////////////////////////////////////////////////////
  22. import Foundation
  /// Native WebSocket engine: runs the HTTP upgrade handshake over a `Transport`, then encodes
  /// outgoing frames with a `Framer` and reassembles incoming ones through a `FrameCollector`.
  ///
  /// The connection-state flags (`isConnecting`, `canSend`, `didUpgrade`) are updated under
  /// `mutex` because the engine is driven both by public API calls and by transport / framer /
  /// HTTP-handler callbacks, which may come from different queues or threads. Outgoing frames
  /// are encoded and handed to the transport in order on `writeQueue`.
  public class WSEngine: Engine, TransportEventClient, FramerEventClient,
  FrameCollectorDelegate, HTTPHandlerDelegate {
      private let transport: Transport
      private let framer: Framer
      private let httpHandler: HTTPHandler
      private let compressionHandler: CompressionHandler?
      private let certPinner: CertificatePinning?
      private let headerChecker: HeaderValidator
      /// Request from the latest `start(request:)`; base of the upgrade request and cookie URL.
      private var request: URLRequest!
      private let frameHandler = FrameCollector()
      /// `true` once the upgrade succeeded; received bytes then go straight to the framer.
      private var didUpgrade = false
      /// `Sec-WebSocket-Key` sent with the current upgrade request; validated against the response.
      private var secKeyValue = ""
      private let writeQueue = DispatchQueue(label: "com.vluxe.starscream.writequeue")
      /// Binary semaphore used as a lock around the connection-state flags.
      private let mutex = DispatchSemaphore(value: 1)
      /// `true` while the upgraded connection accepts outgoing frames.
      private var canSend = false
      /// `true` from `start(request:)` until the handshake succeeds or the engine is torn down.
      private var isConnecting = false
      weak var delegate: EngineDelegate?
      /// When `true`, each received ping is answered automatically with a pong echoing its payload.
      public var respondToPingWithPong: Bool = true

      /// Creates an engine over `transport`.
      ///
      /// - Parameters:
      ///   - transport: Byte stream the WebSocket runs on.
      ///   - certPinner: Optional certificate pinning passed to `transport.connect`.
      ///   - headerValidator: Validates the upgrade response headers against the sent key.
      ///   - httpHandler: Serializes the upgrade request and parses the HTTP response.
      ///   - framer: Encodes and decodes WebSocket frames.
      ///   - compressionHandler: Optional message compression; its presence is reported to the framer.
      public init(transport: Transport,
                  certPinner: CertificatePinning? = nil,
                  headerValidator: HeaderValidator = FoundationSecurity(),
                  httpHandler: HTTPHandler = FoundationHTTPHandler(),
                  framer: Framer = WSFramer(),
                  compressionHandler: CompressionHandler? = nil) {
          self.transport = transport
          self.framer = framer
          self.httpHandler = httpHandler
          self.certPinner = certPinner
          self.headerChecker = headerValidator
          self.compressionHandler = compressionHandler
          framer.updateCompression(supports: compressionHandler != nil)
          frameHandler.delegate = self
      }

      /// Sets the (weakly held) receiver of all `WebSocketEvent`s.
      public func register(delegate: EngineDelegate) {
          self.delegate = delegate
      }

      /// Connects the transport for `request` and begins the WebSocket handshake.
      ///
      /// Does nothing if `request` has no URL, or if the engine is already connecting or connected.
      ///
      /// - Parameter request: Its `url` and `timeoutInterval` are used to connect; it is also the
      ///   base of the HTTP upgrade request sent once the transport reports `.connected`.
      public func start(request: URLRequest) {
          guard let url = request.url else {
              return
          }
          // Check and claim the "connecting" state in one critical section so two concurrent
          // callers cannot both open a connection.
          guard beginConnecting() else {
              return
          }
          self.request = request
          transport.register(delegate: self)
          framer.register(delegate: self)
          httpHandler.register(delegate: self)
          frameHandler.delegate = self
          transport.connect(url: url, timeout: request.timeoutInterval, certificatePinning: certPinner)
      }

      /// Atomically marks the engine as connecting.
      /// - Returns: `false`, changing nothing, if the engine is already connecting or connected.
      private func beginConnecting() -> Bool {
          mutex.wait()
          defer { mutex.signal() }
          if isConnecting || canSend {
              return false
          }
          isConnecting = true
          return true
      }

      /// Closes the connection with `closeCode`.
      ///
      /// - While the WebSocket is open, a close frame carrying `closeCode` is written and the
      ///   engine is torn down (`forceStop()`) once the transport finishes that write.
      /// - While a connection attempt or handshake is still pending, no close frame can be sent
      ///   (writes are dropped until the upgrade succeeds), so the engine is torn down immediately.
      /// - When the engine is idle, this is a no-op.
      ///
      /// - Parameter closeCode: Status code written as the 2-byte close payload via `writeUint16`.
      public func stop(closeCode: UInt16 = CloseCode.normal.rawValue) {
          mutex.wait()
          let isOpen = canSend
          let isPending = isConnecting
          mutex.signal()
          guard isOpen else {
              // Without this branch the close frame would be silently dropped and its completion
              // (the only teardown path) would never run, leaking the transport and leaving
              // `isConnecting` stuck, which blocks any later `start(request:)`.
              if isPending {
                  forceStop()
              }
              return
          }
          var bytes = [UInt8](repeating: 0, count: MemoryLayout<UInt16>.size)
          writeUint16(&bytes, offset: 0, value: closeCode)
          write(data: Data(bytes), opcode: .connectionClose, completion: { [weak self] in
              self?.forceStop()
          })
      }

      /// Tears the connection down immediately, without a closing handshake.
      ///
      /// Clears every connection-state flag, so later writes are dropped and `start(request:)`
      /// may be called again, and then disconnects the transport.
      public func forceStop() {
          reset()
          transport.disconnect()
      }

      /// Sends `string` as a single UTF-8 text frame. See `write(data:opcode:completion:)`.
      public func write(string: String, completion: (() -> ())?) {
          write(data: Data(string.utf8), opcode: .textFrame, completion: completion)
      }

      /// Queues one frame with the given opcode and payload.
      ///
      /// Frames are encoded and passed to the transport in submission order on `writeQueue`. If
      /// the connection is not open when the frame is dequeued, it is dropped and `completion` is
      /// never called. When a compression handler is configured, only text and binary frames are
      /// compressed.
      ///
      /// - Parameters:
      ///   - data: Uncompressed frame payload.
      ///   - opcode: WebSocket opcode of the frame.
      ///   - completion: Invoked from the transport's write completion, ignoring the value the
      ///     transport passes to it.
      public func write(data: Data, opcode: FrameOpCode, completion: (() -> ())?) {
          writeQueue.async { [weak self] in
              guard let self = self else { return }
              self.mutex.wait()
              let canWrite = self.canSend
              self.mutex.signal()
              guard canWrite else {
                  return
              }
              var isCompressed = false
              var sendData = data
              // RFC 7692 §6.1: the per-message-compressed bit must not be set on control frames
              // (close/ping/pong) or continuation frames, so only data messages are compressed.
              if WSEngine.isCompressibleFrame(opcode),
                 let compressedData = self.compressionHandler?.compress(data: data) {
                  sendData = compressedData
                  isCompressed = true
              }
              let frameData = self.framer.createWriteFrame(opcode: opcode, payload: sendData, isCompressed: isCompressed)
              self.transport.write(data: frameData, completion: { _ in
                  completion?()
              })
          }
      }

      /// Whether a frame with `opcode` may carry a compressed payload (first frame of a data message).
      private static func isCompressibleFrame(_ opcode: FrameOpCode) -> Bool {
          switch opcode {
          case .textFrame, .binaryFrame:
              return true
          default:
              return false
          }
      }

      // MARK: - TransportEventClient

      /// Reacts to transport state changes.
      ///
      /// On `.connected`, it sends the HTTP upgrade request. Received bytes go to the HTTP handler
      /// until the upgrade succeeds and to the framer afterwards. Other states are either forwarded
      /// to the delegate or handled as errors.
      public func connectionChanged(state: ConnectionState) {
          switch state {
          case .connected:
              secKeyValue = HTTPWSHeader.generateWebSocketKey()
              let wsReq = HTTPWSHeader.createUpgrade(request: request, supportsCompression: framer.supportsCompression(), secKeyValue: secKeyValue)
              let data = httpHandler.convert(request: wsReq)
              transport.write(data: data, completion: { _ in })
          case .waiting:
              break
          case .failed(let error):
              handleError(error)
          case .viability(let isViable):
              broadcast(event: .viabilityChanged(isViable))
          case .shouldReconnect(let status):
              broadcast(event: .reconnectSuggested(status))
          case .receive(let data):
              if didUpgrade {
                  framer.add(data: data)
              } else {
                  // Presumably `parse` returns how many bytes the HTTP response consumed; any
                  // bytes after that offset are handed to the framer. TODO confirm.
                  let offset = httpHandler.parse(data: data)
                  if offset > 0 {
                      let extraData = data.subdata(in: offset..<data.endIndex)
                      framer.add(data: extraData)
                  }
              }
          case .cancelled:
              // The transport is gone; clear every flag (not only `isConnecting`) so a stale
              // `canSend`/`didUpgrade` cannot block or corrupt the next `start(request:)`.
              reset()
              broadcast(event: .cancelled)
          case .peerClosed:
              broadcast(event: .peerClosed)
          }
      }

      // MARK: - HTTPHandlerDelegate

      /// Completes the handshake.
      ///
      /// On `.success`, it validates the response headers against the sent key. If validation
      /// passes, it marks the connection open, loads compression settings, stores response cookies
      /// in `HTTPCookieStorage.shared`, and broadcasts `.connected`. Validation and HTTP failures
      /// are handled as errors.
      public func didReceiveHTTP(event: HTTPEvent) {
          switch event {
          case .success(let headers):
              if let error = headerChecker.validate(headers: headers, key: secKeyValue) {
                  handleError(error)
                  return
              }
              mutex.wait()
              isConnecting = false
              didUpgrade = true
              canSend = true
              mutex.signal()
              compressionHandler?.load(headers: headers)
              if let url = request.url {
                  HTTPCookie.cookies(withResponseHeaderFields: headers, for: url).forEach {
                      HTTPCookieStorage.shared.setCookie($0)
                  }
              }
              broadcast(event: .connected(headers))
          case .failure(let error):
              handleError(error)
          }
      }

      // MARK: - FramerEventClient

      /// Passes decoded frames to the frame collector; framer errors are handled as errors.
      public func frameProcessed(event: FrameEvent) {
          switch event {
          case .frame(let frame):
              frameHandler.add(frame: frame)
          case .error(let error):
              handleError(error)
          }
      }

      // MARK: - FrameCollectorDelegate

      /// Decompresses a payload for the frame collector; returns `nil` when no compression
      /// handler is configured (or the handler itself returns `nil`).
      public func decompress(data: Data, isFinal: Bool) -> Data? {
          return compressionHandler?.decompress(data: data, isFinal: isFinal)
      }

      /// Forwards completed messages to the delegate.
      ///
      /// It answers pings with pongs when `respondToPingWithPong` is set, and replies to a
      /// received close by closing with the same code.
      public func didForm(event: FrameCollector.Event) {
          switch event {
          case .text(let string):
              broadcast(event: .text(string))
          case .binary(let data):
              broadcast(event: .binary(data))
          case .pong(let data):
              broadcast(event: .pong(data))
          case .ping(let data):
              broadcast(event: .ping(data))
              if respondToPingWithPong {
                  write(data: data ?? Data(), opcode: .pong, completion: nil)
              }
          case .closed(let reason, let code):
              broadcast(event: .disconnected(reason, code))
              stop(closeCode: code)
          case .error(let error):
              handleError(error)
          }
      }

      /// Delivers `event` to the delegate, if one is registered.
      private func broadcast(event: WebSocketEvent) {
          delegate?.didReceive(event: event)
      }

      /// Closes the connection (with the error's close code when it is a `WSError`, otherwise the
      /// default code) and then reports the error to the delegate.
      ///
      /// This can be called from many different queues/threads; touch shared state only under `mutex`.
      private func handleError(_ error: Error?) {
          if let wsError = error as? WSError {
              stop(closeCode: wsError.code)
          } else {
              stop()
          }
          broadcast(event: .error(error))
      }

      /// Returns the engine to the idle state: not connecting, not sendable, not upgraded.
      private func reset() {
          mutex.wait()
          isConnecting = false
          canSend = false
          didUpgrade = false
          mutex.signal()
      }
  }