CGCUDANV.cpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. //===----- CGCUDANV.cpp - Interface to NVIDIA CUDA Runtime ----------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // This provides a class for CUDA code generation targeting the NVIDIA CUDA
  11. // runtime library.
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "CGCUDARuntime.h"
  15. #include "CodeGenFunction.h"
  16. #include "CodeGenModule.h"
  17. #include "clang/AST/Decl.h"
  18. #include "llvm/BasicBlock.h"
  19. #include "llvm/Constants.h"
  20. #include "llvm/DerivedTypes.h"
  21. #include "llvm/Support/CallSite.h"
  22. #include <vector>
  23. using namespace clang;
  24. using namespace CodeGen;
  25. namespace {
  26. class CGNVCUDARuntime : public CGCUDARuntime {
  27. private:
  28. llvm::Type *IntTy, *SizeTy;
  29. llvm::PointerType *CharPtrTy, *VoidPtrTy;
  30. llvm::Constant *getSetupArgumentFn() const;
  31. llvm::Constant *getLaunchFn() const;
  32. public:
  33. CGNVCUDARuntime(CodeGenModule &CGM);
  34. void EmitDeviceStubBody(CodeGenFunction &CGF, FunctionArgList &Args);
  35. };
  36. }
  37. CGNVCUDARuntime::CGNVCUDARuntime(CodeGenModule &CGM) : CGCUDARuntime(CGM) {
  38. CodeGen::CodeGenTypes &Types = CGM.getTypes();
  39. ASTContext &Ctx = CGM.getContext();
  40. IntTy = Types.ConvertType(Ctx.IntTy);
  41. SizeTy = Types.ConvertType(Ctx.getSizeType());
  42. CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy));
  43. VoidPtrTy = cast<llvm::PointerType>(Types.ConvertType(Ctx.VoidPtrTy));
  44. }
  45. llvm::Constant *CGNVCUDARuntime::getSetupArgumentFn() const {
  46. // cudaError_t cudaSetupArgument(void *, size_t, size_t)
  47. std::vector<llvm::Type*> Params;
  48. Params.push_back(VoidPtrTy);
  49. Params.push_back(SizeTy);
  50. Params.push_back(SizeTy);
  51. return CGM.CreateRuntimeFunction(llvm::FunctionType::get(IntTy,
  52. Params, false),
  53. "cudaSetupArgument");
  54. }
  55. llvm::Constant *CGNVCUDARuntime::getLaunchFn() const {
  56. // cudaError_t cudaLaunch(char *)
  57. std::vector<llvm::Type*> Params;
  58. Params.push_back(CharPtrTy);
  59. return CGM.CreateRuntimeFunction(llvm::FunctionType::get(IntTy,
  60. Params, false),
  61. "cudaLaunch");
  62. }
  63. void CGNVCUDARuntime::EmitDeviceStubBody(CodeGenFunction &CGF,
  64. FunctionArgList &Args) {
  65. // Build the argument value list and the argument stack struct type.
  66. llvm::SmallVector<llvm::Value *, 16> ArgValues;
  67. std::vector<llvm::Type *> ArgTypes;
  68. for (FunctionArgList::const_iterator I = Args.begin(), E = Args.end();
  69. I != E; ++I) {
  70. llvm::Value *V = CGF.GetAddrOfLocalVar(*I);
  71. ArgValues.push_back(V);
  72. assert(isa<llvm::PointerType>(V->getType()) && "Arg type not PointerType");
  73. ArgTypes.push_back(cast<llvm::PointerType>(V->getType())->getElementType());
  74. }
  75. llvm::StructType *ArgStackTy = llvm::StructType::get(
  76. CGF.getLLVMContext(), ArgTypes);
  77. llvm::BasicBlock *EndBlock = CGF.createBasicBlock("setup.end");
  78. // Emit the calls to cudaSetupArgument
  79. llvm::Constant *cudaSetupArgFn = getSetupArgumentFn();
  80. for (unsigned I = 0, E = Args.size(); I != E; ++I) {
  81. llvm::Value *Args[3];
  82. llvm::BasicBlock *NextBlock = CGF.createBasicBlock("setup.next");
  83. Args[0] = CGF.Builder.CreatePointerCast(ArgValues[I], VoidPtrTy);
  84. Args[1] = CGF.Builder.CreateIntCast(
  85. llvm::ConstantExpr::getSizeOf(ArgTypes[I]),
  86. SizeTy, false);
  87. Args[2] = CGF.Builder.CreateIntCast(
  88. llvm::ConstantExpr::getOffsetOf(ArgStackTy, I),
  89. SizeTy, false);
  90. llvm::CallSite CS = CGF.EmitCallOrInvoke(cudaSetupArgFn, Args);
  91. llvm::Constant *Zero = llvm::ConstantInt::get(IntTy, 0);
  92. llvm::Value *CSZero = CGF.Builder.CreateICmpEQ(CS.getInstruction(), Zero);
  93. CGF.Builder.CreateCondBr(CSZero, NextBlock, EndBlock);
  94. CGF.EmitBlock(NextBlock);
  95. }
  96. // Emit the call to cudaLaunch
  97. llvm::Constant *cudaLaunchFn = getLaunchFn();
  98. llvm::Value *Arg = CGF.Builder.CreatePointerCast(CGF.CurFn, CharPtrTy);
  99. CGF.EmitCallOrInvoke(cudaLaunchFn, Arg);
  100. CGF.EmitBranch(EndBlock);
  101. CGF.EmitBlock(EndBlock);
  102. }
  103. CGCUDARuntime *CodeGen::CreateNVCUDARuntime(CodeGenModule &CGM) {
  104. return new CGNVCUDARuntime(CGM);
  105. }