spdm-socket.c 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  /* SPDX-License-Identifier: BSD-3-Clause */
  /*
   * QEMU SPDM socket support
   *
   * This is based on:
   * https://github.com/DMTF/spdm-emu/blob/07c0a838bcc1c6207c656ac75885c0603e344b6f/spdm_emu/spdm_emu_common/command.c
   * but has been re-written to match QEMU style
   *
   * Copyright (c) 2021, DMTF. All rights reserved.
   * Copyright (c) 2023. Western Digital Corporation or its affiliates.
   */
  12. #include "qemu/osdep.h"
  13. #include "system/spdm-socket.h"
  14. #include "qapi/error.h"
  /*
   * Read exactly @number_of_bytes bytes from @socket into @buffer.
   *
   * recv() may deliver fewer bytes than requested, so it is called
   * repeatedly until the requested amount has arrived.  A call that was
   * interrupted by a signal (EINTR) is retried rather than reported as a
   * failure.
   *
   * Returns true once all bytes were read (trivially true for a zero-length
   * request); false if recv() reported end-of-stream (0) or any other error.
   * On failure @buffer may hold a partial read.
   */
  static bool read_bytes(const int socket, uint8_t *buffer,
                         size_t number_of_bytes)
  {
      size_t number_received = 0;

      while (number_received < number_of_bytes) {
          ssize_t result = recv(socket, buffer + number_received,
                                number_of_bytes - number_received, 0);

          if (result < 0 && errno == EINTR) {
              /* Interrupted before any data was transferred: try again. */
              continue;
          }
          if (result <= 0) {
              return false;
          }
          number_received += (size_t)result;
      }
      return true;
  }
  /*
   * Read one 32-bit value from @socket into @data and convert it from
   * network to host byte order.
   *
   * Returns false, without byte-swapping @data, when the four bytes could
   * not be read; @data may then hold partially received bytes.
   */
  static bool read_data32(const int socket, uint32_t *data)
  {
      if (read_bytes(socket, (uint8_t *)data, sizeof(*data))) {
          *data = ntohl(*data);
          return true;
      }
      return false;
  }
  /*
   * Read a length-prefixed payload from @socket.
   *
   * A 32-bit length is read first (via read_data32()); that many payload
   * bytes are then read into @buffer.
   *
   * @bytes_received: optional; when non-NULL it receives the announced
   *                  payload length.
   * @max_buffer_length: largest payload length that is accepted.
   *
   * Returns false if the length cannot be read, if it exceeds
   * @max_buffer_length (in which case the payload is not read from the
   * socket), or if reading the payload fails.  A zero-length payload
   * succeeds without touching @buffer.
   */
  static bool read_multiple_bytes(const int socket, uint8_t *buffer,
                                  uint32_t *bytes_received,
                                  uint32_t max_buffer_length)
  {
      uint32_t payload_len;

      if (!read_data32(socket, &payload_len) ||
          payload_len > max_buffer_length) {
          return false;
      }

      if (bytes_received != NULL) {
          *bytes_received = payload_len;
      }

      return payload_len == 0 || read_bytes(socket, buffer, payload_len);
  }
  /*
   * Receive one message from @socket: a 32-bit command word, a 32-bit
   * transport type word, then a length-prefixed payload.
   *
   * @command: receives the command word; it is written as soon as that word
   *           has been read, even if a later step fails.
   * @receive_buffer: destination for the payload.
   * @bytes_to_receive: in: capacity of @receive_buffer; out: payload length.
   *                    Left untouched when the function fails.
   *
   * Returns true when the whole message was read.
   */
  static bool receive_platform_data(const int socket,
                                    uint32_t transport_type,
                                    uint32_t *command,
                                    uint8_t *receive_buffer,
                                    uint32_t *bytes_to_receive)
  {
      uint32_t payload_len = 0;
      uint32_t command_word;

      if (!read_data32(socket, &command_word)) {
          return false;
      }
      *command = command_word;

      /*
       * The transport type word read here overwrites the local copy of the
       * parameter and is not looked at afterwards.
       * NOTE(review): the received value is never compared with the type the
       * caller passed in - confirm whether a mismatch should be rejected.
       */
      if (!read_data32(socket, &transport_type)) {
          return false;
      }

      if (!read_multiple_bytes(socket, receive_buffer, &payload_len,
                               *bytes_to_receive)) {
          return false;
      }

      *bytes_to_receive = payload_len;
      return true;
  }
  /*
   * Write exactly @number_of_bytes bytes from @buffer to @socket.
   *
   * send() may accept fewer bytes than requested, so it is called repeatedly
   * until everything has been handed over.  A call that was interrupted by a
   * signal (EINTR) is retried rather than reported as a failure.
   *
   * Returns true once all bytes were sent (trivially true for a zero-length
   * request, in which case @buffer is not dereferenced); false if send()
   * failed with any other error.
   */
  static bool write_bytes(const int socket, const uint8_t *buffer,
                          uint32_t number_of_bytes)
  {
      uint32_t number_sent = 0;

      while (number_sent < number_of_bytes) {
          ssize_t result = send(socket, buffer + number_sent,
                                number_of_bytes - number_sent, 0);

          if (result < 0) {
              if (errno == EINTR) {
                  /* Interrupted before any data was transferred: retry. */
                  continue;
              }
              return false;
          }
          /* send() never reports more than was requested, so this fits. */
          number_sent += (uint32_t)result;
      }
      return true;
  }
  /*
   * Send the 32-bit value @data on @socket in network byte order.
   * Returns the result of write_bytes() for the four bytes.
   */
  static bool write_data32(const int socket, uint32_t data)
  {
      const uint32_t wire_value = htonl(data);

      return write_bytes(socket, (const uint8_t *)&wire_value,
                         sizeof(wire_value));
  }
  /*
   * Send a length-prefixed payload on @socket: first @bytes_to_send as a
   * 32-bit value (via write_data32()), then that many bytes from @buffer.
   *
   * Returns false as soon as either write fails; the payload is not
   * attempted when the length prefix could not be sent.
   */
  static bool write_multiple_bytes(const int socket, const uint8_t *buffer,
                                   uint32_t bytes_to_send)
  {
      if (!write_data32(socket, bytes_to_send)) {
          return false;
      }
      return write_bytes(socket, buffer, bytes_to_send);
  }
  /*
   * Send one message on @socket: the 32-bit @command, the 32-bit
   * @transport_type, then @send_buffer as a length-prefixed payload.
   *
   * @bytes_to_send is a size_t here but is converted to uint32_t when it is
   * handed to write_multiple_bytes().
   *
   * The three parts are written in order; the first failure stops the
   * sequence and makes the function return false.
   */
  static bool send_platform_data(const int socket,
                                 uint32_t transport_type, uint32_t command,
                                 const uint8_t *send_buffer, size_t bytes_to_send)
  {
      return write_data32(socket, command) &&
             write_data32(socket, transport_type) &&
             write_multiple_bytes(socket, send_buffer, bytes_to_send);
  }
  133. int spdm_socket_connect(uint16_t port, Error **errp)
  134. {
  135. int client_socket;
  136. struct sockaddr_in server_addr;
  137. client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  138. if (client_socket < 0) {
  139. error_setg(errp, "cannot create socket: %s", strerror(errno));
  140. return -1;
  141. }
  142. memset((char *)&server_addr, 0, sizeof(server_addr));
  143. server_addr.sin_family = AF_INET;
  144. server_addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  145. server_addr.sin_port = htons(port);
  146. if (connect(client_socket, (struct sockaddr *)&server_addr,
  147. sizeof(server_addr)) < 0) {
  148. error_setg(errp, "cannot connect: %s", strerror(errno));
  149. close(client_socket);
  150. return -1;
  151. }
  152. return client_socket;
  153. }
  154. uint32_t spdm_socket_rsp(const int socket, uint32_t transport_type,
  155. void *req, uint32_t req_len,
  156. void *rsp, uint32_t rsp_len)
  157. {
  158. uint32_t command;
  159. bool result;
  160. result = send_platform_data(socket, transport_type,
  161. SPDM_SOCKET_COMMAND_NORMAL,
  162. req, req_len);
  163. if (!result) {
  164. return 0;
  165. }
  166. result = receive_platform_data(socket, transport_type, &command,
  167. (uint8_t *)rsp, &rsp_len);
  168. if (!result) {
  169. return 0;
  170. }
  171. assert(command != 0);
  172. return rsp_len;
  173. }
  174. void spdm_socket_close(const int socket, uint32_t transport_type)
  175. {
  176. send_platform_data(socket, transport_type,
  177. SPDM_SOCKET_COMMAND_SHUTDOWN, NULL, 0);
  178. }