OrcTestCommon.h 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. //===------ OrcTestCommon.h - Utilities for Orc Unit Tests ------*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // Common utilities for the Orc unit tests.
  10. //
  11. //===----------------------------------------------------------------------===//
  12. #ifndef LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H
  13. #define LLVM_UNITTESTS_EXECUTIONENGINE_ORC_ORCTESTCOMMON_H
  14. #include "llvm/ExecutionEngine/ExecutionEngine.h"
  15. #include "llvm/ExecutionEngine/JITSymbol.h"
  16. #include "llvm/ExecutionEngine/Orc/IndirectionUtils.h"
  17. #include "llvm/IR/Function.h"
  18. #include "llvm/IR/IRBuilder.h"
  19. #include "llvm/IR/LLVMContext.h"
  20. #include "llvm/IR/Module.h"
  21. #include "llvm/Object/ObjectFile.h"
  22. #include "llvm/Support/TargetRegistry.h"
  23. #include "llvm/Support/TargetSelect.h"
  24. #include "gtest/gtest.h"
  25. #include <memory>
  26. namespace llvm {
  27. namespace orc {
  28. // CoreAPIsStandardTest that saves a bunch of boilerplate by providing the
  29. // following:
  30. //
  31. // (1) ES -- An ExecutionSession
  32. // (2) Foo, Bar, Baz, Qux -- SymbolStringPtrs for strings "foo", "bar", "baz",
  33. // and "qux" respectively.
  34. // (3) FooAddr, BarAddr, BazAddr, QuxAddr -- Dummy addresses. Guaranteed
  35. // distinct and non-null.
  36. // (4) FooSym, BarSym, BazSym, QuxSym -- JITEvaluatedSymbols with FooAddr,
  37. // BarAddr, BazAddr, and QuxAddr respectively. All with default strong,
  38. // linkage and non-hidden visibility.
  39. // (5) V -- A JITDylib associated with ES.
  40. class CoreAPIsBasedStandardTest : public testing::Test {
  41. protected:
  42. std::shared_ptr<SymbolStringPool> SSP = std::make_shared<SymbolStringPool>();
  43. ExecutionSession ES{SSP};
  44. JITDylib &JD = ES.createJITDylib("JD");
  45. SymbolStringPtr Foo = ES.intern("foo");
  46. SymbolStringPtr Bar = ES.intern("bar");
  47. SymbolStringPtr Baz = ES.intern("baz");
  48. SymbolStringPtr Qux = ES.intern("qux");
  49. static const JITTargetAddress FooAddr = 1U;
  50. static const JITTargetAddress BarAddr = 2U;
  51. static const JITTargetAddress BazAddr = 3U;
  52. static const JITTargetAddress QuxAddr = 4U;
  53. JITEvaluatedSymbol FooSym =
  54. JITEvaluatedSymbol(FooAddr, JITSymbolFlags::Exported);
  55. JITEvaluatedSymbol BarSym =
  56. JITEvaluatedSymbol(BarAddr, JITSymbolFlags::Exported);
  57. JITEvaluatedSymbol BazSym =
  58. JITEvaluatedSymbol(BazAddr, JITSymbolFlags::Exported);
  59. JITEvaluatedSymbol QuxSym =
  60. JITEvaluatedSymbol(QuxAddr, JITSymbolFlags::Exported);
  61. };
  62. } // end namespace orc
  63. class OrcNativeTarget {
  64. public:
  65. static void initialize() {
  66. if (!NativeTargetInitialized) {
  67. InitializeNativeTarget();
  68. InitializeNativeTargetAsmParser();
  69. InitializeNativeTargetAsmPrinter();
  70. NativeTargetInitialized = true;
  71. }
  72. }
  73. private:
  74. static bool NativeTargetInitialized;
  75. };
  76. class SimpleMaterializationUnit : public orc::MaterializationUnit {
  77. public:
  78. using MaterializeFunction =
  79. std::function<void(orc::MaterializationResponsibility)>;
  80. using DiscardFunction =
  81. std::function<void(const orc::JITDylib &, orc::SymbolStringPtr)>;
  82. using DestructorFunction = std::function<void()>;
  83. SimpleMaterializationUnit(
  84. orc::SymbolFlagsMap SymbolFlags, MaterializeFunction Materialize,
  85. DiscardFunction Discard = DiscardFunction(),
  86. DestructorFunction Destructor = DestructorFunction())
  87. : MaterializationUnit(std::move(SymbolFlags), orc::VModuleKey()),
  88. Materialize(std::move(Materialize)), Discard(std::move(Discard)),
  89. Destructor(std::move(Destructor)) {}
  90. ~SimpleMaterializationUnit() override {
  91. if (Destructor)
  92. Destructor();
  93. }
  94. StringRef getName() const override { return "<Simple>"; }
  95. void materialize(orc::MaterializationResponsibility R) override {
  96. Materialize(std::move(R));
  97. }
  98. void discard(const orc::JITDylib &JD,
  99. const orc::SymbolStringPtr &Name) override {
  100. if (Discard)
  101. Discard(JD, std::move(Name));
  102. else
  103. llvm_unreachable("Discard not supported");
  104. }
  105. private:
  106. MaterializeFunction Materialize;
  107. DiscardFunction Discard;
  108. DestructorFunction Destructor;
  109. };
  110. // Base class for Orc tests that will execute code.
  111. class OrcExecutionTest {
  112. public:
  113. OrcExecutionTest() {
  114. // Initialize the native target if it hasn't been done already.
  115. OrcNativeTarget::initialize();
  116. // Try to select a TargetMachine for the host.
  117. TM.reset(EngineBuilder().selectTarget());
  118. if (TM) {
  119. // If we found a TargetMachine, check that it's one that Orc supports.
  120. const Triple& TT = TM->getTargetTriple();
  121. // Bail out for windows platforms. We do not support these yet.
  122. if ((TT.getArch() != Triple::x86_64 && TT.getArch() != Triple::x86) ||
  123. TT.isOSWindows())
  124. return;
  125. // Target can JIT?
  126. SupportsJIT = TM->getTarget().hasJIT();
  127. // Use ability to create callback manager to detect whether Orc
  128. // has indirection support on this platform. This way the test
  129. // and Orc code do not get out of sync.
  130. SupportsIndirection = !!orc::createLocalCompileCallbackManager(TT, ES, 0);
  131. }
  132. };
  133. protected:
  134. orc::ExecutionSession ES;
  135. LLVMContext Context;
  136. std::unique_ptr<TargetMachine> TM;
  137. bool SupportsJIT = false;
  138. bool SupportsIndirection = false;
  139. };
  140. class ModuleBuilder {
  141. public:
  142. ModuleBuilder(LLVMContext &Context, StringRef Triple,
  143. StringRef Name);
  144. Function *createFunctionDecl(FunctionType *FTy, StringRef Name) {
  145. return Function::Create(FTy, GlobalValue::ExternalLinkage, Name, M.get());
  146. }
  147. Module* getModule() { return M.get(); }
  148. const Module* getModule() const { return M.get(); }
  149. std::unique_ptr<Module> takeModule() { return std::move(M); }
  150. private:
  151. std::unique_ptr<Module> M;
  152. };
  153. // Dummy struct type.
  154. struct DummyStruct {
  155. int X[256];
  156. };
  157. inline StructType *getDummyStructTy(LLVMContext &Context) {
  158. return StructType::get(ArrayType::get(Type::getInt32Ty(Context), 256));
  159. }
  160. template <typename HandleT, typename ModuleT>
  161. class MockBaseLayer {
  162. public:
  163. using ModuleHandleT = HandleT;
  164. using AddModuleSignature =
  165. Expected<ModuleHandleT>(ModuleT M,
  166. std::shared_ptr<JITSymbolResolver> R);
  167. using RemoveModuleSignature = Error(ModuleHandleT H);
  168. using FindSymbolSignature = JITSymbol(const std::string &Name,
  169. bool ExportedSymbolsOnly);
  170. using FindSymbolInSignature = JITSymbol(ModuleHandleT H,
  171. const std::string &Name,
  172. bool ExportedSymbolsONly);
  173. using EmitAndFinalizeSignature = Error(ModuleHandleT H);
  174. std::function<AddModuleSignature> addModuleImpl;
  175. std::function<RemoveModuleSignature> removeModuleImpl;
  176. std::function<FindSymbolSignature> findSymbolImpl;
  177. std::function<FindSymbolInSignature> findSymbolInImpl;
  178. std::function<EmitAndFinalizeSignature> emitAndFinalizeImpl;
  179. Expected<ModuleHandleT> addModule(ModuleT M,
  180. std::shared_ptr<JITSymbolResolver> R) {
  181. assert(addModuleImpl &&
  182. "addModule called, but no mock implementation was provided");
  183. return addModuleImpl(std::move(M), std::move(R));
  184. }
  185. Error removeModule(ModuleHandleT H) {
  186. assert(removeModuleImpl &&
  187. "removeModule called, but no mock implementation was provided");
  188. return removeModuleImpl(H);
  189. }
  190. JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) {
  191. assert(findSymbolImpl &&
  192. "findSymbol called, but no mock implementation was provided");
  193. return findSymbolImpl(Name, ExportedSymbolsOnly);
  194. }
  195. JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name,
  196. bool ExportedSymbolsOnly) {
  197. assert(findSymbolInImpl &&
  198. "findSymbolIn called, but no mock implementation was provided");
  199. return findSymbolInImpl(H, Name, ExportedSymbolsOnly);
  200. }
  201. Error emitAndFinaliez(ModuleHandleT H) {
  202. assert(emitAndFinalizeImpl &&
  203. "emitAndFinalize called, but no mock implementation was provided");
  204. return emitAndFinalizeImpl(H);
  205. }
  206. };
  207. class ReturnNullJITSymbol {
  208. public:
  209. template <typename... Args>
  210. JITSymbol operator()(Args...) const {
  211. return nullptr;
  212. }
  213. };
  // Function object that accepts any arguments by value, ignores them, and
  // returns a copy of a fixed value. Handy for stubbing callbacks whose
  // effect does not matter to a test.
  //
  // ReturnT must be copy-constructible: operator() is const and hands back a
  // copy of the stored value on every call. The <void> specialization covers
  // callbacks that return nothing.
  template <typename ReturnT>
  class DoNothingAndReturn {
  public:
    DoNothingAndReturn(ReturnT Ret) : Ret(std::move(Ret)) {}

    // Must return ReturnT: with a void return type, "return Ret;" would make
    // every instantiation of this operator ill-formed.
    template <typename... Args>
    ReturnT operator()(Args...) const { return Ret; }

  private:
    ReturnT Ret;
  };

  // Specialization for callbacks with no result: ignores its arguments and
  // does nothing.
  template <>
  class DoNothingAndReturn<void> {
  public:
    template <typename... Args>
    void operator()(Args...) const {}
  };
  229. } // namespace llvm
  230. #endif