/*
 * Copyright 2017, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Wrapper.h"

#include "llvm/IR/Module.h"

#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using bcinfo::MetadataExtractor;

namespace android {
namespace spirit {

namespace {

// Binding numbers (descriptor set 0) used by the generated SPIR-V. The values
// are exactly the literals previously hard-coded at the use sites.
// NOTE(review): these presumably must agree with the RSoV runtime's
// descriptor layout -- confirm before changing any of them.
constexpr uint32_t kOutputBinding = 2;     // Kernel output buffer.
constexpr uint32_t kFirstInputBinding = 3; // Kernel input i uses 3 + i.
// NOTE(review): rs_allocation globals are also numbered from 3, i.e. they
// share binding numbers with kernel inputs in descriptor set 0. TODO confirm
// whether that overlap is intended.
constexpr uint32_t kFirstAllocationBinding = 3;

} // anonymous namespace

// Declares a buffer variable holding a runtime array of |elementType|.
//
// The runtime array gets ArrayStride = m->getSize(elementType). It is wrapped
// as member 0 (Offset 0) of a struct decorated BufferBlock. A Uniform-storage
// variable of pointer-to-that-struct type is created, decorated with
// DescriptorSet 0 and Binding |binding|, and added to |m|.
//
// Returns the newly created variable.
VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
                        Module *m) {
  auto ArrTy = m->getRuntimeArrayType(elementType);
  const size_t stride = m->getSize(elementType);
  ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);

  auto StructTy = m->getStructType(ArrTy);
  StructTy->decorate(Decoration::BufferBlock);
  StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);

  auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);

  VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
  bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
  bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);

  m->addVariable(bufferVar);

  return bufferVar;
}

// Generates a GLCompute entry point named "entry_<name>" that invokes the
// kernel function |name| once per invocation.
//
// Steps performed by the wrapper:
//  - It computes a linear element index from the invocation id.
//  - It loads |numInput| inputs from buffers bound at kFirstInputBinding + i.
//  - It appends the X/Y/Z coordinates requested by |signature|.
//  - It calls the kernel.
//  - If |signature| has an output, it stores the result into the buffer at
//    kOutputBinding.
//
// Returns:
//  - true if the wrapper was emitted.
//  - true if the kernel is absent but |name| is exactly "root" (see below).
//  - false for unsupported signatures (old-style kernel, user data argument,
//    context argument).
//  - false if a non-root kernel cannot be found.
bool AddWrapper(const char *name, const uint32_t signature,
                const uint32_t numInput, Builder &b, Module *m) {
  FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
  if (kernel == nullptr) {
    // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
    // is always reserved for the root kernel, even though in the most recent
    // RS apps it does not exist. Simply bypass wrapper generation here, and
    // return true for this case.
    // Otherwise, if a non-root kernel function cannot be found, it is a
    // fatal internal error which is really unexpected.
    // Compare the full name: a prefix match would also silently skip a
    // missing kernel whose name merely starts with "root" (e.g. "root2").
    return std::strcmp(name, "root") == 0;
  }

  // The following three cases are not supported
  if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
    // Not handling old-style kernel
    return false;
  }

  if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
    // Not handling the user argument
    return false;
  }

  if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
    // Not handling the context argument
    return false;
  }

  // The wrapper itself is a `void()` function with a single block.
  TypeVoidInst *VoidTy = m->getVoidType();
  TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
  FunctionDefinition *Func =
      b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
  m->addFunctionDefinition(Func);

  Block *Blk = b.MakeBlock();
  Func->addBlock(Blk);

  Blk->addInstruction(b.MakeLabel());

  TypeIntInst *UIntTy = m->getUnsignedIntType(32);

  Instruction *XValue = nullptr;
  Instruction *YValue = nullptr;
  Instruction *ZValue = nullptr;
  Instruction *Index = nullptr;
  VariableInst *InvocationId = nullptr;
  VariableInst *NumWorkgroups = nullptr;

  // Coordinates and the linear index are only materialized when something
  // consumes them: an input, an output, or an explicit x/y/z parameter.
  if (MetadataExtractor::hasForEachSignatureIn(signature) ||
      MetadataExtractor::hasForEachSignatureOut(signature) ||
      MetadataExtractor::hasForEachSignatureX(signature) ||
      MetadataExtractor::hasForEachSignatureY(signature) ||
      MetadataExtractor::hasForEachSignatureZ(signature)) {
    TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
    InvocationId = m->getInvocationId();
    auto IID = b.MakeLoad(V3UIntTy, InvocationId);
    Blk->addInstruction(IID);

    XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
    Blk->addInstruction(XValue);

    YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
    Blk->addInstruction(YValue);

    ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
    Blk->addInstruction(ZValue);

    // TODO: Use SpecConstant for workgroup size
    // The workgroup size is currently the constant (1, 1, 1), matching the
    // setLocalSize(1, 1, 1) call on the entry point below.
    auto ConstOne = m->getConstant(UIntTy, 1U);
    auto GroupSize =
        m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);

    auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
    Blk->addInstruction(GroupSizeX);

    auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
    Blk->addInstruction(GroupSizeY);

    NumWorkgroups = m->getNumWorkgroups();
    auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
    Blk->addInstruction(NumGroup);

    auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
    Blk->addInstruction(NumGroupX);

    auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
    Blk->addInstruction(NumGroupY);

    // Index = X + GlobalSizeX * (Y + GlobalSizeY * Z), where
    // GlobalSize = GroupSize * NumWorkgroups on each axis.
    auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
    Blk->addInstruction(GlobalSizeX);

    auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
    Blk->addInstruction(GlobalSizeY);

    auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
    Blk->addInstruction(RowsAlongZ);

    auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
    Blk->addInstruction(NumRows);

    auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
    Blk->addInstruction(NumCellsFromYZ);

    Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
    Blk->addInstruction(Index);
  }

  std::vector<IdRef> inputs;

  ConstantInst *ConstZero = m->getConstant(UIntTy, 0);

  // Load element |Index| of each input buffer; the element type is taken from
  // the kernel's i-th parameter.
  for (uint32_t i = 0; i < numInput; i++) {
    FunctionParameterInst *param = kernel->getParameter(i);
    Instruction *elementType = param->mResultType.mInstruction;
    VariableInst *inputBuffer =
        AddBuffer(elementType, kFirstInputBinding + i, b, m);

    TypePointerInst *PtrTy =
        m->getPointerType(StorageClass::Function, elementType);
    AccessChainInst *Ptr =
        b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
    Blk->addInstruction(Ptr);

    Instruction *input = b.MakeLoad(elementType, Ptr);
    Blk->addInstruction(input);

    inputs.push_back(IdRef(input));
  }

  // TODO: Convert from unsigned int to signed int if that is what the kernel
  // function takes for the coordinate parameters
  // Coordinates are appended as a nested prefix: Y only when X is present,
  // and Z only when Y is present.
  if (MetadataExtractor::hasForEachSignatureX(signature)) {
    inputs.push_back(XValue);
    if (MetadataExtractor::hasForEachSignatureY(signature)) {
      inputs.push_back(YValue);
      if (MetadataExtractor::hasForEachSignatureZ(signature)) {
        inputs.push_back(ZValue);
      }
    }
  }

  auto resultType = kernel->getReturnType();
  auto kernelCall =
      b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
  Blk->addInstruction(kernelCall);

  if (MetadataExtractor::hasForEachSignatureOut(signature)) {
    VariableInst *OutputBuffer = AddBuffer(resultType, kOutputBinding, b, m);
    auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
    AccessChainInst *OutPtr =
        b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
    Blk->addInstruction(OutPtr);
    Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
  }

  Blk->addInstruction(b.MakeReturn());

  std::string wrapperName("entry_");
  wrapperName.append(name);

  EntryPointDefinition *entry = b.MakeEntryPointDefinition(
      ExecutionModel::GLCompute, Func, wrapperName.c_str());

  entry->setLocalSize(1, 1, 1);

  // The builtin inputs were only loaded when Index was computed above.
  if (Index != nullptr) {
    entry->addToInterface(InvocationId);
    entry->addToInterface(NumWorkgroups);
  }

  m->addEntryPoint(entry);

  return true;
}

// Decorates the "__GPUBlock" global buffer and the rs_allocation globals.
//
// Steps:
//  - The SPIR-V variable "__GPUBlock" gets DescriptorSet 0 / Binding 0, and
//    its struct type is marked BufferBlock.
//  - Each struct member gets an Offset taken from the LLVM DataLayout of the
//    LLVM global of the same name.
//  - The offsets of the exported variables and the global size are recorded
//    as ".rsov.ExportedVars:" and ".rsov.GlobalSize:" strings in |m|.
//  - Each rs_allocation global gets DescriptorSet 0 and consecutive bindings
//    starting at kFirstAllocationBinding.
//
// Returns:
//  - true when there is nothing to decorate, i.e. no "__GPUBlock" in either
//    module.
//  - true when rs_allocation info cannot be extracted.
//  - false on a real error: unexpected LLVM type, missing layout or
//    decoration, an out-of-range export index, or an unknown allocation
//    variable.
bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
  Instruction *inst = m->lookupByName("__GPUBlock");
  if (inst == nullptr) {
    return true;
  }

  VariableInst *bufferVar = static_cast<VariableInst *>(inst);
  bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
  bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);

  TypePointerInst *StructPtrTy =
      static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
  TypeStructInst *StructTy =
      static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
  StructTy->decorate(Decoration::BufferBlock);

  // Decorate each member with proper offsets
  const auto GlobalsB = LM.globals().begin();
  const auto GlobalsE = LM.globals().end();
  const auto Found =
      std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
        return GV.getName() == "__GPUBlock";
      });

  if (Found == GlobalsE) {
    return true; // GPUBlock not found - not an error by itself.
  }

  const llvm::GlobalVariable &G = *Found;

  rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();

  // The LLVM global must be a pointer to a struct; anything else is an error.
  const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType());
  auto *LStructTy =
      LPtrTy != nullptr
          ? llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())
          : nullptr;
  if (LStructTy == nullptr) {
    return false;
  }

  const auto &DLayout = LM.getDataLayout();
  const auto *SLayout = DLayout.getStructLayout(LStructTy);
  assert(SLayout);
  if (SLayout == nullptr) {
    std::cerr << "struct layout is null" << std::endl;
    return false;
  }

  const uint32_t numFields = LStructTy->getNumElements();
  std::vector<uint32_t> offsets;
  offsets.reserve(numFields);
  for (uint32_t i = 0; i != numFields; ++i) {
    auto decor = StructTy->memberDecorate(i, Decoration::Offset);
    if (!decor) {
      std::cerr << "failed creating member decoration for field " << i
                << std::endl;
      return false;
    }
    const uint32_t offset =
        static_cast<uint32_t>(SLayout->getElementOffset(i));
    decor->addExtraOperand(offset);
    offsets.push_back(offset);
  }

  std::stringstream ssOffsets;
  // TODO: define this string in a central place
  ssOffsets << ".rsov.ExportedVars:";
  for (uint32_t slot = 0; slot < Ctxt.getNumExportVar(); slot++) {
    const uint32_t index = Ctxt.getExportVarIndex(slot);
    // The index comes from the context, not from the struct itself; guard
    // against it pointing past the last field instead of reading out of
    // bounds.
    if (index >= offsets.size()) {
      std::cerr << "exported variable " << slot
                << " refers to out-of-range field " << index << std::endl;
      return false;
    }
    ssOffsets << offsets[index] << ';';
  }
  m->addString(ssOffsets.str().c_str());

  std::stringstream ssGlobalSize;
  ssGlobalSize << ".rsov.GlobalSize:" << Ctxt.getGlobalSize();
  m->addString(ssGlobalSize.str().c_str());

  llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
  if (!getRSAllocationInfo(LM, RSAllocs)) {
    // llvm::errs() << "Extracting rs_allocation info failed\n";
    return true;
  }

  // TODO: clean up the binding number assignment
  size_t BindingNum = kFirstAllocationBinding;
  for (const auto &A : RSAllocs) {
    Instruction *allocInst = m->lookupByName(A.VarName.c_str());
    if (allocInst == nullptr) {
      return false;
    }
    VariableInst *allocVar = static_cast<VariableInst *>(allocInst);
    allocVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
    allocVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
  }

  return true;
}

// Emits the module header: the Shader capability, the Logical/GLSL450 memory
// model, a GLSL 450 source declaration and its source extensions.
void AddHeader(Module *m) {
  m->addCapability(Capability::Shader);
  m->setMemoryModel(AddressingModel::Logical, MemoryModel::GLSL450);

  m->addSource(SourceLanguage::GLSL, 450);
  m->addSourceExtension("GL_ARB_separate_shader_objects");
  m->addSourceExtension("GL_ARB_shading_language_420pack");
  m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
  m->addSourceExtension("GL_GOOGLE_include_directive");
}

namespace {

// Rewrites the Function storage class to Uniform on every pointer type,
// forward pointer and variable it visits. All other storage classes are left
// alone.
class StorageClassVisitor : public DoNothingVisitor {
public:
  void visit(TypePointerInst *inst) override {
    matchAndReplace(inst->mOperand1);
  }

  void visit(TypeForwardPointerInst *inst) override {
    matchAndReplace(inst->mOperand2);
  }

  void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }

private:
  void matchAndReplace(StorageClass &storage) {
    if (storage == StorageClass::Function) {
      storage = StorageClass::Uniform;
    }
  }
};

// Applies StorageClassVisitor to the module's global section only.
void FixGlobalStorageClass(Module *m) {
  StorageClassVisitor v;
  m->getGlobalSection()->accept(&v);
}

} // anonymous namespace

// Entry point of the wrapper pass.
//
// The pass fixes global storage classes, emits the header, and decorates the
// global buffers. It then emits one wrapper per exported foreach kernel
// listed in the bitcode metadata and finally consolidates annotations.
//
// Returns false as soon as any step reports failure.
bool AddWrappers(llvm::Module &LM, android::spirit::Module *m) {
  rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
  const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
  // NOTE(review): |m| keeps a pointer to this stack-local builder after this
  // function returns. Verify that nothing uses m's builder afterwards.
  android::spirit::Builder b;

  m->setBuilder(&b);

  FixGlobalStorageClass(m);

  AddHeader(m);

  // DecorateGlobalBuffer returns false only for genuine errors, so propagate
  // them instead of continuing with a partially decorated module.
  if (!DecorateGlobalBuffer(LM, b, m)) {
    return false;
  }

  const size_t numKernel = metadata.getExportForEachSignatureCount();
  const char **kernelName = metadata.getExportForEachNameList();
  const uint32_t *kernelSignature = metadata.getExportForEachSignatureList();
  const uint32_t *inputCount = metadata.getExportForEachInputCountList();

  for (size_t i = 0; i < numKernel; i++) {
    if (!AddWrapper(kernelName[i], kernelSignature[i], inputCount[i], b, m)) {
      return false;
    }
  }

  m->consolidateAnnotations();
  return true;
}

// SPIRIT pass wrapper around AddWrappers(). Any failure is reported through
// |*error| as -1; success reports 0.
class WrapperPass : public Pass {
public:
  // NOTE(review): the const is cast away because AddWrappers takes a mutable
  // llvm::Module&. Nothing in this file visibly mutates it.
  WrapperPass(const llvm::Module &LM)
      : mLLVMModule(const_cast<llvm::Module &>(LM)) {}

  Module *run(Module *m, int *error) override {
    bool success = AddWrappers(mLLVMModule, m);
    if (error) {
      *error = success ? 0 : -1;
    }
    return m;
  }

private:
  llvm::Module &mLLVMModule;
};

} // namespace spirit
} // namespace android

namespace rs2spirv {

// Creates a heap-allocated WrapperPass bound to |LLVMModule|; the caller takes
// ownership of the returned pointer.
android::spirit::Pass *CreateWrapperPass(const llvm::Module &LLVMModule) {
  return new android::spirit::WrapperPass(LLVMModule);
}

} // namespace rs2spirv