/*
* Copyright 2017, The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "Wrapper.h"
#include "llvm/IR/Module.h"
#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"
#include <cstring>
#include <sstream>
#include <vector>
using bcinfo::MetadataExtractor;
namespace android {
namespace spirit {
// Declares a Uniform-storage buffer variable holding a runtime array of
// |elementType| and registers it with module |m|.
//
// The variable's type is a pointer to a BufferBlock-decorated struct whose only
// member (at offset 0) is the runtime array; the array stride is taken from
// m->getSize(elementType). The variable is placed in descriptor set 0 at the
// given |binding|.
//
// Returns the newly created variable.
VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
                        Module *m) {
  // Runtime array of elements, with an explicit stride.
  auto *runtimeArrayTy = m->getRuntimeArrayType(elementType);
  const size_t elementStride = m->getSize(elementType);
  runtimeArrayTy->decorate(Decoration::ArrayStride)
      ->addExtraOperand(elementStride);

  // Wrap the array in a single-member block struct.
  auto *blockTy = m->getStructType(runtimeArrayTy);
  blockTy->decorate(Decoration::BufferBlock);
  blockTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);

  // Create the variable itself and attach its descriptor set / binding.
  auto *blockPtrTy = m->getPointerType(StorageClass::Uniform, blockTy);
  VariableInst *buffer = b.MakeVariable(blockPtrTy, StorageClass::Uniform);
  buffer->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
  buffer->decorate(Decoration::Binding)->addExtraOperand(binding);

  m->addVariable(buffer);
  return buffer;
}
// Generates a GLCompute entry point named "entry_<name>" that wraps the kernel
// function |name| found in module |m|.
//
// The wrapper:
//   - loads the invocation id and the number of workgroups and derives a
//     linear element index
//       Index = GlobalSizeX * (Y + GlobalSizeY * Z) + X
//     where GlobalSize{X,Y} = GroupSize{X,Y} * NumWorkgroups{X,Y} and the
//     group size is the constant (1, 1, 1);
//   - for each of the |numInput| kernel inputs, declares a buffer at binding
//     (i + 3) and loads element [Index] from it;
//   - appends the x / y / z coordinates requested by |signature|;
//   - calls the kernel and, if the signature has an output, stores the result
//     at [Index] of a buffer declared at binding 2.
//
// Parameters:
//   name      - name of the kernel function to wrap.
//   signature - foreach signature bits, decoded with bcinfo::MetadataExtractor.
//   numInput  - number of input allocations the kernel takes.
//   b         - builder used to create instructions.
//   m         - module the wrapper is added to.
//
// Returns true on success (including the case of a missing "root" kernel, for
// which nothing is generated), false if the kernel cannot be found or uses an
// unsupported / inconsistent signature.
bool AddWrapper(const char *name, const uint32_t signature,
                const uint32_t numInput, Builder &b, Module *m) {
  FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
  if (kernel == nullptr) {
    // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
    // is always reserved for the root kernel, even though in the most recent RS
    // apps it does not exist. Simply bypass wrapper generation here, and return
    // true for this case.
    // Otherwise, if a non-root kernel function cannot be found, it is a
    // fatal internal error which is really unexpected.
    //
    // The comparison is exact: a prefix match would also silently accept any
    // missing kernel whose name merely starts with "root".
    return strcmp(name, "root") == 0;
  }

  // The following three cases are not supported
  if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
    // Not handling old-style kernel
    return false;
  }
  if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
    // Not handling the user argument
    return false;
  }
  if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
    // Not handling the context argument
    return false;
  }

  const bool hasOutput = MetadataExtractor::hasForEachSignatureOut(signature);
  const bool hasX = MetadataExtractor::hasForEachSignatureX(signature);
  const bool hasY = MetadataExtractor::hasForEachSignatureY(signature);
  const bool hasZ = MetadataExtractor::hasForEachSignatureZ(signature);

  // The invocation id / linear index is only computed when something in the
  // signature consumes it.
  const bool needsIndex = MetadataExtractor::hasForEachSignatureIn(signature) ||
                          hasOutput || hasX || hasY || hasZ;

  // Input buffers are addressed with the linear index. If the signature does
  // not request one but inputs are declared, the metadata is inconsistent;
  // bail out before emitting anything rather than indexing with a null id.
  if (numInput > 0 && !needsIndex) {
    return false;
  }

  // void entry() { ... }
  TypeVoidInst *VoidTy = m->getVoidType();
  TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
  FunctionDefinition *Func =
      b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
  m->addFunctionDefinition(Func);

  Block *Blk = b.MakeBlock();
  Func->addBlock(Blk);
  Blk->addInstruction(b.MakeLabel());

  TypeIntInst *UIntTy = m->getUnsignedIntType(32);

  Instruction *XValue = nullptr;
  Instruction *YValue = nullptr;
  Instruction *ZValue = nullptr;
  Instruction *Index = nullptr;
  VariableInst *InvocationId = nullptr;
  VariableInst *NumWorkgroups = nullptr;

  if (needsIndex) {
    TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);

    // Split the invocation id into its x / y / z components.
    InvocationId = m->getInvocationId();
    auto IID = b.MakeLoad(V3UIntTy, InvocationId);
    Blk->addInstruction(IID);
    XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
    Blk->addInstruction(XValue);
    YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
    Blk->addInstruction(YValue);
    ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
    Blk->addInstruction(ZValue);

    // TODO: Use SpecConstant for workgroup size
    auto ConstOne = m->getConstant(UIntTy, 1U);
    auto GroupSize =
        m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
    auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
    Blk->addInstruction(GroupSizeX);
    auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
    Blk->addInstruction(GroupSizeY);

    NumWorkgroups = m->getNumWorkgroups();
    auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
    Blk->addInstruction(NumGroup);
    auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
    Blk->addInstruction(NumGroupX);
    auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
    Blk->addInstruction(NumGroupY);

    // Global extent along x and y = group size * number of groups.
    auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
    Blk->addInstruction(GlobalSizeX);
    auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
    Blk->addInstruction(GlobalSizeY);

    // Index = GlobalSizeX * (Y + GlobalSizeY * Z) + X
    auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
    Blk->addInstruction(RowsAlongZ);
    auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
    Blk->addInstruction(NumRows);
    auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
    Blk->addInstruction(NumCellsFromYZ);
    Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
    Blk->addInstruction(Index);
  }

  // Arguments for the kernel call: one element per input buffer first, then
  // the requested coordinates.
  std::vector<IdRef> inputs;
  inputs.reserve(static_cast<size_t>(numInput) + 3);

  ConstantInst *ConstZero = m->getConstant(UIntTy, 0);

  for (uint32_t i = 0; i < numInput; i++) {
    FunctionParameterInst *param = kernel->getParameter(i);
    Instruction *elementType = param->mResultType.mInstruction;
    // Bindings 0 and 2 are used elsewhere in this file (global buffer and
    // output buffer respectively); inputs start at binding 3.
    VariableInst *inputBuffer = AddBuffer(elementType, i + 3, b, m);

    TypePointerInst *PtrTy =
        m->getPointerType(StorageClass::Function, elementType);
    // Member 0 of the block struct is the runtime array; pick element [Index].
    AccessChainInst *Ptr =
        b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
    Blk->addInstruction(Ptr);

    Instruction *input = b.MakeLoad(elementType, Ptr);
    Blk->addInstruction(input);

    inputs.push_back(IdRef(input));
  }

  // TODO: Convert from unsigned int to signed int if that is what the kernel
  // function takes for the coordinate parameters
  // NOTE(review): y is only forwarded when x is present, and z only when both
  // x and y are present - confirm that kernels never request y or z alone.
  if (hasX) {
    inputs.push_back(XValue);
    if (hasY) {
      inputs.push_back(YValue);
      if (hasZ) {
        inputs.push_back(ZValue);
      }
    }
  }

  auto resultType = kernel->getReturnType();
  auto kernelCall =
      b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
  Blk->addInstruction(kernelCall);

  if (hasOutput) {
    // Store the kernel's return value at [Index] of the output buffer.
    VariableInst *OutputBuffer = AddBuffer(resultType, 2, b, m);
    auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
    AccessChainInst *OutPtr =
        b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
    Blk->addInstruction(OutPtr);
    Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
  }

  Blk->addInstruction(b.MakeReturn());

  std::string wrapperName("entry_");
  wrapperName.append(name);

  EntryPointDefinition *entry = b.MakeEntryPointDefinition(
      ExecutionModel::GLCompute, Func, wrapperName.c_str());
  entry->setLocalSize(1, 1, 1);

  // The builtin variables are part of the interface only when they were
  // actually loaded above.
  if (Index != nullptr) {
    entry->addToInterface(InvocationId);
    entry->addToInterface(NumWorkgroups);
  }

  m->addEntryPoint(entry);
  return true;
}
// Decorates the "__GPUBlock" global buffer variable of |m| and the variables
// that correspond to rs_allocation globals of |LM|.
//
// For "__GPUBlock" this:
//   - assigns descriptor set 0 / binding 0 to the variable and marks its
//     struct type as BufferBlock;
//   - decorates every struct member with the byte offset reported by the LLVM
//     data layout for the matching LLVM global;
//   - records two strings in the module: ".rsov.ExportedVars:" followed by the
//     ';'-terminated offsets of the exported variables, and
//     ".rsov.GlobalSize:" followed by the context's global size.
//
// Each rs_allocation variable reported by getRSAllocationInfo() is then placed
// in descriptor set 0, with bindings assigned sequentially starting at 3.
//
// Parameters:
//   LM - LLVM module the SPIR-V module was generated from.
//   b  - builder (not used by this function).
//   m  - SPIR-V module to decorate.
//
// Returns true on success, and also when "__GPUBlock" does not exist in either
// module or when rs_allocation info cannot be extracted. Returns false when
// the LLVM global does not have the expected pointer-to-struct type, when a
// member decoration cannot be created, when an exported variable refers to a
// field that does not exist, or when an rs_allocation variable cannot be found
// by name in |m|.
bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
  Instruction *inst = m->lookupByName("__GPUBlock");
  if (inst == nullptr) {
    return true;
  }

  VariableInst *bufferVar = static_cast<VariableInst *>(inst);
  bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
  bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);

  // The variable's result type is a pointer whose second operand is the
  // struct type being pointed to.
  TypePointerInst *StructPtrTy =
      static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
  TypeStructInst *StructTy =
      static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
  StructTy->decorate(Decoration::BufferBlock);

  // Decorate each member with proper offsets
  const auto GlobalsB = LM.globals().begin();
  const auto GlobalsE = LM.globals().end();
  const auto Found =
      std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
        return GV.getName() == "__GPUBlock";
      });

  if (Found == GlobalsE) {
    return true; // GPUBlock not found - not an error by itself.
  }

  const llvm::GlobalVariable &G = *Found;

  rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
  bool IsCorrectTy = false;
  if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
    if (auto *LStructTy =
            llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
      IsCorrectTy = true;

      const auto &DLayout = LM.getDataLayout();
      const auto *SLayout = DLayout.getStructLayout(LStructTy);
      assert(SLayout);
      // Kept in addition to the assert so that builds with asserts disabled
      // still fail cleanly.
      if (SLayout == nullptr) {
        std::cerr << "struct layout is null" << '\n';
        return false;
      }

      // Byte offset of every struct field, indexed by field number.
      std::vector<uint32_t> offsets;
      for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
        auto decor = StructTy->memberDecorate(i, Decoration::Offset);
        if (!decor) {
          std::cerr << "failed creating member decoration for field " << i
                    << '\n';
          return false;
        }
        const uint32_t offset =
            static_cast<uint32_t>(SLayout->getElementOffset(i));
        decor->addExtraOperand(offset);
        offsets.push_back(offset);
      }

      std::stringstream ssOffsets;
      // TODO: define this string in a central place
      ssOffsets << ".rsov.ExportedVars:";
      for (uint32_t slot = 0; slot < Ctxt.getNumExportVar(); slot++) {
        const uint32_t index = Ctxt.getExportVarIndex(slot);
        // The index comes from the context, not from the struct type, so it
        // must be validated before it is used to subscript |offsets|.
        if (index >= offsets.size()) {
          std::cerr << "exported variable slot " << slot
                    << " refers to out-of-range field " << index << '\n';
          return false;
        }
        const uint32_t offset = offsets[index];
        ssOffsets << offset << ';';
      }
      m->addString(ssOffsets.str().c_str());

      std::stringstream ssGlobalSize;
      ssGlobalSize << ".rsov.GlobalSize:" << Ctxt.getGlobalSize();
      m->addString(ssGlobalSize.str().c_str());
    }
  }

  if (!IsCorrectTy) {
    return false;
  }

  llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
  if (!getRSAllocationInfo(LM, RSAllocs)) {
    // llvm::errs() << "Extracting rs_allocation info failed\n";
    return true;
  }

  // TODO: clean up the binding number assignment
  size_t BindingNum = 3;
  for (const auto &A : RSAllocs) {
    Instruction *allocInst = m->lookupByName(A.VarName.c_str());
    if (allocInst == nullptr) {
      return false;
    }
    VariableInst *allocVar = static_cast<VariableInst *>(allocInst);
    allocVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
    allocVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
  }

  return true;
}
// Fills in the module-level preamble of |m|: the Shader capability, the
// Logical / GLSL450 memory model, GLSL 450 as the source language, and the
// list of source extensions below (added in the listed order).
void AddHeader(Module *m) {
  static const char *const kSourceExtensions[] = {
      "GL_ARB_separate_shader_objects",
      "GL_ARB_shading_language_420pack",
      "GL_GOOGLE_cpp_style_line_directive",
      "GL_GOOGLE_include_directive",
  };

  m->addCapability(Capability::Shader);
  m->setMemoryModel(AddressingModel::Logical, MemoryModel::GLSL450);
  m->addSource(SourceLanguage::GLSL, 450);

  for (const char *extension : kSourceExtensions) {
    m->addSourceExtension(extension);
  }
}
namespace {
// Visitor that rewrites the storage class operand of pointer types, forward
// pointer declarations and variables: every occurrence of
// StorageClass::Function becomes StorageClass::Uniform. Other storage classes
// are left untouched.
class StorageClassVisitor : public DoNothingVisitor {
public:
  // Pointer types carry the storage class in their first operand.
  void visit(TypePointerInst *inst) override {
    matchAndReplace(inst->mOperand1);
  }

  // Forward pointer declarations carry it in their second operand.
  void visit(TypeForwardPointerInst *inst) override {
    matchAndReplace(inst->mOperand2);
  }

  // Variables carry it in their first operand.
  void visit(VariableInst *inst) override {
    matchAndReplace(inst->mOperand1);
  }

private:
  // Overwrites |storage| in place when it is Function.
  void matchAndReplace(StorageClass &storage) {
    const bool isFunctionStorage = (storage == StorageClass::Function);
    if (isFunctionStorage) {
      storage = StorageClass::Uniform;
    }
  }
};
// Runs a StorageClassVisitor over the global section of |m|.
void FixGlobalStorageClass(Module *m) {
  StorageClassVisitor storageFixer;
  auto *globalSection = m->getGlobalSection();
  globalSection->accept(&storageFixer);
}
} // anonymous namespace
// Prepares module |m| for use as a compute shader and adds one entry-point
// wrapper per exported foreach kernel.
//
// Steps, in order:
//   1. install a builder on |m|;
//   2. rewrite Function storage classes in the global section;
//   3. add the module header;
//   4. decorate the global buffer and rs_allocation variables;
//   5. call AddWrapper() for every kernel listed in the context's metadata;
//   6. consolidate annotations.
//
// Parameters:
//   LM - LLVM module that |m| was generated from.
//   m  - SPIR-V module to modify.
//
// Returns false as soon as decorating the global buffer or generating any
// wrapper fails (annotations are then not consolidated); true otherwise.
bool AddWrappers(llvm::Module &LM,
                 android::spirit::Module *m) {
  rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
  const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();

  // NOTE(review): |m| is handed the address of this function-local builder and
  // it is never reset here - confirm that the module does not use its builder
  // after this function returns.
  android::spirit::Builder b;
  m->setBuilder(&b);

  FixGlobalStorageClass(m);
  AddHeader(m);

  // DecorateGlobalBuffer() reports real failures (unexpected type, missing
  // variable, ...) through its return value; do not keep going with a
  // partially decorated module.
  if (!DecorateGlobalBuffer(LM, b, m)) {
    return false;
  }

  const size_t numKernel = metadata.getExportForEachSignatureCount();
  const char **kernelName = metadata.getExportForEachNameList();
  const uint32_t *kernelSignature = metadata.getExportForEachSignatureList();
  const uint32_t *inputCount = metadata.getExportForEachInputCountList();

  for (size_t i = 0; i < numKernel; i++) {
    const bool success =
        AddWrapper(kernelName[i], kernelSignature[i], inputCount[i], b, m);
    if (!success) {
      return false;
    }
  }

  m->consolidateAnnotations();
  return true;
}
// Pass that runs AddWrappers() on a SPIR-V module, using the LLVM module
// supplied at construction time.
class WrapperPass : public Pass {
public:
  // Keeps a reference to |LM|; the pass does not copy the LLVM module.
  // Constness is cast away because AddWrappers() takes a non-const
  // llvm::Module reference.
  explicit WrapperPass(const llvm::Module &LM)
      : mLLVMModule(const_cast<llvm::Module &>(LM)) {}

  // Runs AddWrappers() on |m| and returns |m| itself, whether or not the
  // transformation succeeded. If |error| is non-null it receives 0 on
  // success and -1 on failure.
  Module *run(Module *m, int *error) override {
    const bool success = AddWrappers(mLLVMModule, m);
    if (error) {
      *error = success ? 0 : -1;
    }
    return m;
  }

private:
  llvm::Module &mLLVMModule;
};
} // namespace spirit
} // namespace android
namespace rs2spirv {
// Factory for the wrapper-generation pass operating on |LLVMModule|.
// The pass is allocated with new and returned as a raw pointer; nothing in
// this function releases it.
android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
  android::spirit::Pass *wrapperPass =
      new android::spirit::WrapperPass(LLVMModule);
  return wrapperPass;
}
} // namespace rs2spirv