CGCUDABuiltin.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. //===----- CGCUDABuiltin.cpp - Codegen for CUDA builtins ------------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // Generates code for built-in CUDA calls which are not runtime-specific.
  11. // (Runtime-specific codegen lives in CGCUDARuntime.)
  12. //
  13. //===----------------------------------------------------------------------===//
  14. #include "CodeGenFunction.h"
  15. #include "clang/Basic/Builtins.h"
  16. #include "llvm/IR/DataLayout.h"
  17. #include "llvm/IR/Instruction.h"
  18. #include "llvm/Support/MathExtras.h"
  19. using namespace clang;
  20. using namespace CodeGen;
  21. static llvm::Function *GetVprintfDeclaration(llvm::Module &M) {
  22. llvm::Type *ArgTypes[] = {llvm::Type::getInt8PtrTy(M.getContext()),
  23. llvm::Type::getInt8PtrTy(M.getContext())};
  24. llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(
  25. llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);
  26. if (auto* F = M.getFunction("vprintf")) {
  27. // Our CUDA system header declares vprintf with the right signature, so
  28. // nobody else should have been able to declare vprintf with a bogus
  29. // signature.
  30. assert(F->getFunctionType() == VprintfFuncType);
  31. return F;
  32. }
  33. // vprintf doesn't already exist; create a declaration and insert it into the
  34. // module.
  35. return llvm::Function::Create(
  36. VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, "vprintf", &M);
  37. }
  38. // Transforms a call to printf into a call to the NVPTX vprintf syscall (which
  39. // isn't particularly special; it's invoked just like a regular function).
  40. // vprintf takes two args: A format string, and a pointer to a buffer containing
  41. // the varargs.
  42. //
  43. // For example, the call
  44. //
  45. // printf("format string", arg1, arg2, arg3);
  46. //
  47. // is converted into something resembling
  48. //
  49. // struct Tmp {
  50. // Arg1 a1;
  51. // Arg2 a2;
  52. // Arg3 a3;
  53. // };
  54. // char* buf = alloca(sizeof(Tmp));
  55. // *(Tmp*)buf = {a1, a2, a3};
  56. // vprintf("format string", buf);
  57. //
  58. // buf is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of the
  59. // args is itself aligned to its preferred alignment.
  60. //
  61. // Note that by the time this function runs, E's args have already undergone the
  62. // standard C vararg promotion (short -> int, float -> double, etc.).
  63. RValue
  64. CodeGenFunction::EmitCUDADevicePrintfCallExpr(const CallExpr *E,
  65. ReturnValueSlot ReturnValue) {
  66. assert(getLangOpts().CUDA);
  67. assert(getLangOpts().CUDAIsDevice);
  68. assert(E->getBuiltinCallee() == Builtin::BIprintf);
  69. assert(E->getNumArgs() >= 1); // printf always has at least one arg.
  70. const llvm::DataLayout &DL = CGM.getDataLayout();
  71. llvm::LLVMContext &Ctx = CGM.getLLVMContext();
  72. CallArgList Args;
  73. EmitCallArgs(Args,
  74. E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
  75. E->arguments(), E->getDirectCallee(),
  76. /* ParamsToSkip = */ 0);
  77. // Construct and fill the args buffer that we'll pass to vprintf.
  78. llvm::Value *BufferPtr;
  79. if (Args.size() <= 1) {
  80. // If there are no args, pass a null pointer to vprintf.
  81. BufferPtr = llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(Ctx));
  82. } else {
  83. llvm::SmallVector<llvm::Type *, 8> ArgTypes;
  84. for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I)
  85. ArgTypes.push_back(Args[I].RV.getScalarVal()->getType());
  86. llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args");
  87. llvm::Value *Alloca = CreateTempAlloca(AllocaTy);
  88. for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) {
  89. llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1);
  90. llvm::Value *Arg = Args[I].RV.getScalarVal();
  91. Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlignment(Arg->getType()));
  92. }
  93. BufferPtr = Builder.CreatePointerCast(Alloca, llvm::Type::getInt8PtrTy(Ctx));
  94. }
  95. // Invoke vprintf and return.
  96. llvm::Function* VprintfFunc = GetVprintfDeclaration(CGM.getModule());
  97. return RValue::get(
  98. Builder.CreateCall(VprintfFunc, {Args[0].RV.getScalarVal(), BufferPtr}));
  99. }