/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "bcc/Assert.h"
#include "bcc/Renderscript/RSUtils.h"
#include "bcc/Support/Log.h"

#include <algorithm>
#include <vector>

#include <llvm/IR/CallSite.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Function.h>
#include <llvm/Pass.h>

namespace { // anonymous namespace

static const bool kDebug = false;

/* RSX86_64CallConvPass: This pass scans for calls to Renderscript functions in
 * the CPU reference driver.  For such calls, it identifies the
 * pass-by-reference large-object pointer arguments introduced by the frontend
 * to conform to the AArch64 calling convention (AAPCS).  These pointer
 * arguments are converted to pass-by-value to match the calling convention of
 * the CPU reference driver.
 *
 * The parameter types themselves are left unchanged (still pointers); the
 * conversion is expressed by adding the 'byval' attribute to the affected
 * parameters, both on the redeclared callee and on every rewritten call.
 */
class RSX86_64CallConvPass: public llvm::ModulePass {
private:
  // Returns true for external (declaration-only), non-intrinsic functions,
  // i.e. the candidates whose parameters may need lowering.
  bool IsRSFunctionOfInterest(llvm::Function &F) {
    // Only Renderscript functions that are not defined locally should be
    // considered.
    if (!F.empty()) // defined locally
      return false;

    // llvm intrinsic or internal function
    llvm::StringRef FName = F.getName();
    if (FName.startswith("llvm."))
      return false;

    // All other functions need to be checked for large-object parameters.
    // Disallowed (non-Renderscript) functions are detected by a different pass.
    return true;
  }

  // Test if this argument needs to be converted to pass-by-value.
  bool IsDerefNeeded(llvm::Function *F, llvm::Argument &Arg) {
    unsigned ArgNo = Arg.getArgNo();
    llvm::Type *ArgTy = Arg.getType();

    // Do not consider arguments with 'sret' attribute.  Parameters with this
    // attribute are actually pointers to structure return values.
    if (Arg.hasStructRetAttr())
      return false;

    // Dereference needed only if type is a pointer to a struct
    if (!ArgTy->isPointerTy() || !ArgTy->getPointerElementType()->isStructTy())
      return false;

    // Dereference needed only for certain RS struct objects.
    llvm::Type *StructTy = ArgTy->getPointerElementType();
    if (!isRsObjectType(StructTy))
      return false;

    // TODO Find a better way to encode exceptions
    llvm::StringRef FName = F->getName();
    // rsSetObject's first parameter is a pointer
    if (FName.find("rsSetObject") != llvm::StringRef::npos && ArgNo == 0)
      return false;
    // rsClearObject's first parameter is a pointer
    if (FName.find("rsClearObject") != llvm::StringRef::npos && ArgNo == 0)
      return false;

    return true;
  }

  // Compute which arguments to this function need be converted to
  // pass-by-value.  Appends their indices to ArgNums (which must be empty on
  // entry) and returns true if at least one argument was found.
  bool FillArgsToDeref(llvm::Function *F, std::vector<unsigned> &ArgNums) {
    bccAssert(ArgNums.empty());

    for (llvm::Argument &Arg : F->args()) {
      if (IsDerefNeeded(F, Arg)) {
        ArgNums.push_back(Arg.getArgNo());

        if (kDebug) {
          ALOGV("Lowering argument %u for function %s\n", Arg.getArgNo(),
                F->getName().str().c_str());
        }
      }
    }
    return !ArgNums.empty();
  }

  // Create a replacement declaration for OrigFn with the same type, linkage,
  // name, and attributes, plus 'byval' on every parameter in ArgsToDeref.
  // OrigFn's name (and its arguments' names) are moved to the new function.
  llvm::Function *RedefineFn(llvm::Function *OrigFn,
                             std::vector<unsigned> &ArgsToDeref) {
    llvm::FunctionType *FTy = OrigFn->getFunctionType();
    std::vector<llvm::Type *> Params(FTy->param_begin(), FTy->param_end());

    llvm::FunctionType *NewTy = llvm::FunctionType::get(FTy->getReturnType(),
                                                        Params,
                                                        FTy->isVarArg());
    llvm::Function *NewFn = llvm::Function::Create(NewTy,
                                                   OrigFn->getLinkage(),
                                                   OrigFn->getName(),
                                                   OrigFn->getParent());

    // copyAttributesFrom() replaces NewFn's entire attribute list with
    // OrigFn's, so it must run BEFORE the ByVal attributes are added;
    // otherwise they would be silently discarded.
    NewFn->copyAttributesFrom(OrigFn);

    // Add the ByVal attribute to the attribute list corresponding to this
    // argument.  The list at index (i+1) corresponds to the i-th argument.  The
    // list at index 0 corresponds to the return value's attribute.
    for (unsigned i : ArgsToDeref) {
      NewFn->addAttribute(i + 1, llvm::Attribute::ByVal);
    }

    // NewFn was auto-renamed on creation to avoid clashing with OrigFn; take
    // the original name over now.
    NewFn->takeName(OrigFn);

    for (auto AI = OrigFn->arg_begin(), AE = OrigFn->arg_end(),
              NAI = NewFn->arg_begin();
         AI != AE; ++AI, ++NAI) {
      NAI->takeName(&*AI);
    }

    return NewFn;
  }

  // Replace the direct call CS with an equivalent call to NewFn that carries
  // the original call's calling convention, attributes and debug location,
  // plus 'byval' on the parameters in ArgsToDeref.  The old call is erased.
  void ReplaceCallInsn(llvm::CallSite &CS,
                       llvm::Function *NewFn,
                       std::vector<unsigned> &ArgsToDeref) {
    llvm::CallInst *CI = llvm::cast<llvm::CallInst>(CS.getInstruction());
    std::vector<llvm::Value *> Args(CS.arg_begin(), CS.arg_end());
    llvm::CallInst *NewCI = llvm::CallInst::Create(NewFn, Args, "", CI);

    // Carry over call-site state; a calling-convention mismatch between call
    // and callee is undefined behavior, and call-site attributes (e.g. sret,
    // zeroext) drive ABI lowering.
    NewCI->setCallingConv(CI->getCallingConv());
    NewCI->setAttributes(CI->getAttributes());
    NewCI->setDebugLoc(CI->getDebugLoc());

    // Add the ByVal attribute to the attribute list corresponding to this
    // argument.  The list at index (i+1) corresponds to the i-th argument.  The
    // list at index 0 corresponds to the return value's attribute.
    for (unsigned i : ArgsToDeref) {
      NewCI->addAttribute(i + 1, llvm::Attribute::ByVal);
    }
    if (CI->isTailCall())
      NewCI->setTailCall();

    if (!CI->getType()->isVoidTy()) {
      CI->replaceAllUsesWith(NewCI);
      NewCI->takeName(CI);
    }

    CI->eraseFromParent();
  }

public:
  static char ID;

  RSX86_64CallConvPass()
    : ModulePass (ID) {
  }

  virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
    // This pass does not use any other analysis passes, but it does
    // modify the existing functions in the module (thus altering the CFG).
  }

  bool runOnModule(llvm::Module &M) override {
    // Avoid adding Functions and altering FunctionList while iterating over it
    // by collecting functions and processing them later.
    std::vector<llvm::Function *> FunctionsToHandle;
    for (llvm::Function &F : M) {
      if (IsRSFunctionOfInterest(F))
        FunctionsToHandle.push_back(&F);
    }

    bool Changed = false;
    for (llvm::Function *OrigFn : FunctionsToHandle) {
      std::vector<unsigned> ArgsToDeref;
      if (!FillArgsToDeref(OrigFn, ArgsToDeref))
        continue;

      llvm::Function *NewFn = RedefineFn(OrigFn, ArgsToDeref);

      // Snapshot the direct call sites first: ReplaceCallInsn erases
      // instructions, which would invalidate a live walk of the use list.
      // A call may use OrigFn more than once (e.g. also as an argument), so
      // deduplicate.
      std::vector<llvm::CallInst *> Calls;
      for (llvm::User *U : OrigFn->users()) {
        auto *CI = llvm::dyn_cast<llvm::CallInst>(U);
        if (CI && CI->getCalledValue() == OrigFn &&
            std::find(Calls.begin(), Calls.end(), CI) == Calls.end()) {
          Calls.push_back(CI);
        }
      }
      for (llvm::CallInst *CI : Calls) {
        llvm::CallSite CS(CI);
        ReplaceCallInsn(CS, NewFn, ArgsToDeref);
      }

      // Any remaining uses are not direct calls (address taken, constant
      // expressions, OrigFn passed as an argument).  NewFn has the identical
      // function type, so they can simply be retargeted.
      OrigFn->replaceAllUsesWith(NewFn);
      OrigFn->eraseFromParent();
      Changed = true;
    }

    return Changed;
  }

};

}

char RSX86_64CallConvPass::ID = 0;

static llvm::RegisterPass<RSX86_64CallConvPass> X("X86-64-calling-conv",
  "remove AArch64 assumptions from calls in X86-64");

namespace bcc {

llvm::ModulePass *
createRSX86_64CallConvPass() {
  return new RSX86_64CallConvPass();
}

}