QueueChannel.h 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
//===-- QueueChannel.h - Queue-based RPC channel for unit tests -*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
  9. #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
#include "llvm/ExecutionEngine/Orc/RawByteChannel.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
  15. namespace llvm {
  16. class QueueChannelError : public ErrorInfo<QueueChannelError> {
  17. public:
  18. static char ID;
  19. };
  20. class QueueChannelClosedError
  21. : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
  22. public:
  23. static char ID;
  24. std::error_code convertToErrorCode() const override {
  25. return inconvertibleErrorCode();
  26. }
  27. void log(raw_ostream &OS) const override {
  28. OS << "Queue closed";
  29. }
  30. };
  31. class Queue : public std::queue<char> {
  32. public:
  33. using ErrorInjector = std::function<Error()>;
  34. Queue()
  35. : ReadError([]() { return Error::success(); }),
  36. WriteError([]() { return Error::success(); }) {}
  37. Queue(const Queue&) = delete;
  38. Queue& operator=(const Queue&) = delete;
  39. Queue(Queue&&) = delete;
  40. Queue& operator=(Queue&&) = delete;
  41. std::mutex &getMutex() { return M; }
  42. std::condition_variable &getCondVar() { return CV; }
  43. Error checkReadError() { return ReadError(); }
  44. Error checkWriteError() { return WriteError(); }
  45. void setReadError(ErrorInjector NewReadError) {
  46. {
  47. std::lock_guard<std::mutex> Lock(M);
  48. ReadError = std::move(NewReadError);
  49. }
  50. CV.notify_one();
  51. }
  52. void setWriteError(ErrorInjector NewWriteError) {
  53. std::lock_guard<std::mutex> Lock(M);
  54. WriteError = std::move(NewWriteError);
  55. }
  56. private:
  57. std::mutex M;
  58. std::condition_variable CV;
  59. std::function<Error()> ReadError, WriteError;
  60. };
  61. class QueueChannel : public orc::rpc::RawByteChannel {
  62. public:
  63. QueueChannel(std::shared_ptr<Queue> InQueue,
  64. std::shared_ptr<Queue> OutQueue)
  65. : InQueue(InQueue), OutQueue(OutQueue) {}
  66. QueueChannel(const QueueChannel&) = delete;
  67. QueueChannel& operator=(const QueueChannel&) = delete;
  68. QueueChannel(QueueChannel&&) = delete;
  69. QueueChannel& operator=(QueueChannel&&) = delete;
  70. template <typename FunctionIdT, typename SequenceIdT>
  71. Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
  72. ++InFlightOutgoingMessages;
  73. return orc::rpc::RawByteChannel::startSendMessage(FnId, SeqNo);
  74. }
  75. Error endSendMessage() {
  76. --InFlightOutgoingMessages;
  77. ++CompletedOutgoingMessages;
  78. return orc::rpc::RawByteChannel::endSendMessage();
  79. }
  80. template <typename FunctionIdT, typename SequenceNumberT>
  81. Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
  82. ++InFlightIncomingMessages;
  83. return orc::rpc::RawByteChannel::startReceiveMessage(FnId, SeqNo);
  84. }
  85. Error endReceiveMessage() {
  86. --InFlightIncomingMessages;
  87. ++CompletedIncomingMessages;
  88. return orc::rpc::RawByteChannel::endReceiveMessage();
  89. }
  90. Error readBytes(char *Dst, unsigned Size) override {
  91. std::unique_lock<std::mutex> Lock(InQueue->getMutex());
  92. while (Size) {
  93. {
  94. Error Err = InQueue->checkReadError();
  95. while (!Err && InQueue->empty()) {
  96. InQueue->getCondVar().wait(Lock);
  97. Err = InQueue->checkReadError();
  98. }
  99. if (Err)
  100. return Err;
  101. }
  102. *Dst++ = InQueue->front();
  103. --Size;
  104. ++NumRead;
  105. InQueue->pop();
  106. }
  107. return Error::success();
  108. }
  109. Error appendBytes(const char *Src, unsigned Size) override {
  110. std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
  111. while (Size--) {
  112. if (Error Err = OutQueue->checkWriteError())
  113. return Err;
  114. OutQueue->push(*Src++);
  115. ++NumWritten;
  116. }
  117. OutQueue->getCondVar().notify_one();
  118. return Error::success();
  119. }
  120. Error send() override {
  121. ++SendCalls;
  122. return Error::success();
  123. }
  124. void close() {
  125. auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
  126. InQueue->setReadError(ChannelClosed);
  127. InQueue->setWriteError(ChannelClosed);
  128. OutQueue->setReadError(ChannelClosed);
  129. OutQueue->setWriteError(ChannelClosed);
  130. }
  131. uint64_t NumWritten = 0;
  132. uint64_t NumRead = 0;
  133. std::atomic<size_t> InFlightIncomingMessages{0};
  134. std::atomic<size_t> CompletedIncomingMessages{0};
  135. std::atomic<size_t> InFlightOutgoingMessages{0};
  136. std::atomic<size_t> CompletedOutgoingMessages{0};
  137. std::atomic<size_t> SendCalls{0};
  138. private:
  139. std::shared_ptr<Queue> InQueue;
  140. std::shared_ptr<Queue> OutQueue;
  141. };
  142. inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
  143. createPairedQueueChannels() {
  144. auto Q1 = std::make_shared<Queue>();
  145. auto Q2 = std::make_shared<Queue>();
  146. auto C1 = std::make_unique<QueueChannel>(Q1, Q2);
  147. auto C2 = std::make_unique<QueueChannel>(Q2, Q1);
  148. return std::make_pair(std::move(C1), std::move(C2));
  149. }
  150. }
  151. #endif