123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
//===------------ QueueChannel.h - RPC Queue Channel ------------*- C++ -*-===//
- //
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
- // See https://llvm.org/LICENSE.txt for license information.
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
- //
- //===----------------------------------------------------------------------===//
- #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
- #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_QUEUECHANNEL_H
#include "llvm/ExecutionEngine/Orc/RawByteChannel.h"
#include "llvm/Support/Error.h"

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
- namespace llvm {
- class QueueChannelError : public ErrorInfo<QueueChannelError> {
- public:
- static char ID;
- };
- class QueueChannelClosedError
- : public ErrorInfo<QueueChannelClosedError, QueueChannelError> {
- public:
- static char ID;
- std::error_code convertToErrorCode() const override {
- return inconvertibleErrorCode();
- }
- void log(raw_ostream &OS) const override {
- OS << "Queue closed";
- }
- };
- class Queue : public std::queue<char> {
- public:
- using ErrorInjector = std::function<Error()>;
- Queue()
- : ReadError([]() { return Error::success(); }),
- WriteError([]() { return Error::success(); }) {}
- Queue(const Queue&) = delete;
- Queue& operator=(const Queue&) = delete;
- Queue(Queue&&) = delete;
- Queue& operator=(Queue&&) = delete;
- std::mutex &getMutex() { return M; }
- std::condition_variable &getCondVar() { return CV; }
- Error checkReadError() { return ReadError(); }
- Error checkWriteError() { return WriteError(); }
- void setReadError(ErrorInjector NewReadError) {
- {
- std::lock_guard<std::mutex> Lock(M);
- ReadError = std::move(NewReadError);
- }
- CV.notify_one();
- }
- void setWriteError(ErrorInjector NewWriteError) {
- std::lock_guard<std::mutex> Lock(M);
- WriteError = std::move(NewWriteError);
- }
- private:
- std::mutex M;
- std::condition_variable CV;
- std::function<Error()> ReadError, WriteError;
- };
- class QueueChannel : public orc::rpc::RawByteChannel {
- public:
- QueueChannel(std::shared_ptr<Queue> InQueue,
- std::shared_ptr<Queue> OutQueue)
- : InQueue(InQueue), OutQueue(OutQueue) {}
- QueueChannel(const QueueChannel&) = delete;
- QueueChannel& operator=(const QueueChannel&) = delete;
- QueueChannel(QueueChannel&&) = delete;
- QueueChannel& operator=(QueueChannel&&) = delete;
- template <typename FunctionIdT, typename SequenceIdT>
- Error startSendMessage(const FunctionIdT &FnId, const SequenceIdT &SeqNo) {
- ++InFlightOutgoingMessages;
- return orc::rpc::RawByteChannel::startSendMessage(FnId, SeqNo);
- }
- Error endSendMessage() {
- --InFlightOutgoingMessages;
- ++CompletedOutgoingMessages;
- return orc::rpc::RawByteChannel::endSendMessage();
- }
- template <typename FunctionIdT, typename SequenceNumberT>
- Error startReceiveMessage(FunctionIdT &FnId, SequenceNumberT &SeqNo) {
- ++InFlightIncomingMessages;
- return orc::rpc::RawByteChannel::startReceiveMessage(FnId, SeqNo);
- }
- Error endReceiveMessage() {
- --InFlightIncomingMessages;
- ++CompletedIncomingMessages;
- return orc::rpc::RawByteChannel::endReceiveMessage();
- }
- Error readBytes(char *Dst, unsigned Size) override {
- std::unique_lock<std::mutex> Lock(InQueue->getMutex());
- while (Size) {
- {
- Error Err = InQueue->checkReadError();
- while (!Err && InQueue->empty()) {
- InQueue->getCondVar().wait(Lock);
- Err = InQueue->checkReadError();
- }
- if (Err)
- return Err;
- }
- *Dst++ = InQueue->front();
- --Size;
- ++NumRead;
- InQueue->pop();
- }
- return Error::success();
- }
- Error appendBytes(const char *Src, unsigned Size) override {
- std::unique_lock<std::mutex> Lock(OutQueue->getMutex());
- while (Size--) {
- if (Error Err = OutQueue->checkWriteError())
- return Err;
- OutQueue->push(*Src++);
- ++NumWritten;
- }
- OutQueue->getCondVar().notify_one();
- return Error::success();
- }
- Error send() override {
- ++SendCalls;
- return Error::success();
- }
- void close() {
- auto ChannelClosed = []() { return make_error<QueueChannelClosedError>(); };
- InQueue->setReadError(ChannelClosed);
- InQueue->setWriteError(ChannelClosed);
- OutQueue->setReadError(ChannelClosed);
- OutQueue->setWriteError(ChannelClosed);
- }
- uint64_t NumWritten = 0;
- uint64_t NumRead = 0;
- std::atomic<size_t> InFlightIncomingMessages{0};
- std::atomic<size_t> CompletedIncomingMessages{0};
- std::atomic<size_t> InFlightOutgoingMessages{0};
- std::atomic<size_t> CompletedOutgoingMessages{0};
- std::atomic<size_t> SendCalls{0};
- private:
- std::shared_ptr<Queue> InQueue;
- std::shared_ptr<Queue> OutQueue;
- };
- inline std::pair<std::unique_ptr<QueueChannel>, std::unique_ptr<QueueChannel>>
- createPairedQueueChannels() {
- auto Q1 = std::make_shared<Queue>();
- auto Q2 = std::make_shared<Queue>();
- auto C1 = std::make_unique<QueueChannel>(Q1, Q2);
- auto C2 = std::make_unique<QueueChannel>(Q2, Q1);
- return std::make_pair(std::move(C1), std::move(C2));
- }
- }
- #endif
|