//===- SPIRVLowerOCLBlocks.cpp - Lower OpenCL blocks ------------*- C++ -*-===//
//
//                     The LLVM/SPIR-V Translator
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
// Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal with the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimers.
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimers in the documentation
// and/or other materials provided with the distribution.
// Neither the names of Advanced Micro Devices, Inc., nor the names of its
// contributors may be used to endorse or promote products derived from this
// Software without specific prior written permission.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
// THE SOFTWARE.
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements lowering of OpenCL blocks to functions.
///
//===----------------------------------------------------------------------===//

#ifndef OCLLOWERBLOCKS_H_
#define OCLLOWERBLOCKS_H_

#include "SPIRVInternal.h"
#include "OCLUtil.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Bitcode/ReaderWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <iostream>
#include <list>
#include <memory>
#include <set>
#include <sstream>
#include <vector>

#define DEBUG_TYPE "spvblocks"

using namespace llvm;
using namespace SPIRV;
using namespace OCLUtil;

namespace SPIRV{

/// Lower SPIR2 blocks to function calls.
///
/// SPIR2 representation of blocks:
///
/// block = spir_block_bind(bitcast(block_func), context_len, context_align,
///   context)
/// block_func_ptr = bitcast(spir_get_block_invoke(block))
/// context_ptr = spir_get_block_context(block)
/// ret = block_func_ptr(context_ptr, args)
///
/// Propagates block_func to each spir_get_block_invoke through def-use chain of
/// spir_block_bind, so that
/// ret = block_func(context, args)
class SPIRVLowerOCLBlocks: public ModulePass {
public:
  SPIRVLowerOCLBlocks():ModulePass(ID), M(nullptr){
    initializeSPIRVLowerOCLBlocksPass(*PassRegistry::getPassRegistry());
  }

  virtual void getAnalysisUsage(AnalysisUsage &AU) const {
    AU.addRequired<CallGraphWrapperPass>();
    //AU.addRequired<AliasAnalysis>();
    AU.addRequired<AssumptionCacheTracker>();
  }

  virtual bool runOnModule(Module &Module) {
    M = &Module;
    lowerBlockBind();
    lowerGetBlockInvoke();
    lowerGetBlockContext();
    erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
    erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT));
    erase(M->getFunction(SPIR_INTRINSIC_BLOCK_BIND));
    DEBUG(dbgs() << "------- After OCLLowerBlocks ------------\n" <<
                    *M << '\n');
    return true;
  }

  static char ID;
private:
  const static int MaxIter = 1000;
  Module *M;

  bool
  lowerBlockBind() {
    auto F = M->getFunction(SPIR_INTRINSIC_BLOCK_BIND);
    if (!F)
      return false;
    int Iter = MaxIter;
    while(lowerBlockBind(F) && Iter > 0){
      Iter--;
      DEBUG(dbgs() << "-------------- after iteration " << MaxIter - Iter <<
          " --------------\n" << *M << '\n');
    }
    assert(Iter > 0 && "Too many iterations");
    return true;
  }

  bool
  eraseUselessFunctions() {
    bool changed = false;
    for (auto I = M->begin(), E = M->end(); I != E;) {
      Function *F = static_cast<Function*>(I++);
      if (!GlobalValue::isInternalLinkage(F->getLinkage()) &&
          !F->isDeclaration())
        continue;

      dumpUsers(F, "[eraseUselessFunctions] ");
      for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
        auto U = *UI++;
        if (auto CE = dyn_cast<ConstantExpr>(U)){
          if (CE->use_empty()) {
            CE->dropAllReferences();
            changed = true;
          }
        }
      }
      if (F->use_empty()) {
        erase(F);
        changed = true;
      }
    }
    return changed;
  }

  void
  lowerGetBlockInvoke() {
    if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)) {
      for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
        auto CI = dyn_cast<CallInst>(*UI++);
        assert(CI && "Invalid usage of spir_get_block_invoke");
        lowerGetBlockInvoke(CI);
      }
    }
  }

  void
  lowerGetBlockContext() {
    if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT)) {
      for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
        auto CI = dyn_cast<CallInst>(*UI++);
        assert(CI && "Invalid usage of spir_get_block_context");
        lowerGetBlockContext(CI);
      }
    }
  }
  /// Lower calls of spir_block_bind.
  /// Return true if the Module is changed.
  bool
  lowerBlockBind(Function *BlockBindFunc) {
    bool changed = false;
    for (auto I = BlockBindFunc->user_begin(), E = BlockBindFunc->user_end();
        I != E;) {
      DEBUG(dbgs() << "[lowerBlockBind] " << **I << '\n');
      // Handle spir_block_bind(bitcast(block_func), context_len,
      // context_align, context)
      auto CallBlkBind = cast<CallInst>(*I++);
      Function *InvF = nullptr;
      Value *Ctx = nullptr;
      Value *CtxLen = nullptr;
      Value *CtxAlign = nullptr;
      getBlockInvokeFuncAndContext(CallBlkBind, &InvF, &Ctx, &CtxLen,
          &CtxAlign);
      for (auto II = CallBlkBind->user_begin(), EE = CallBlkBind->user_end();
          II != EE;) {
        auto BlkUser = *II++;
        SPIRVDBG(dbgs() << "  Block user: " << *BlkUser << '\n');
        if (auto Ret = dyn_cast<ReturnInst>(BlkUser)) {
          bool Inlined = false;
          changed |= lowerReturnBlock(Ret, CallBlkBind, Inlined);
          if (Inlined)
            return true;
        } else if (auto CI = dyn_cast<CallInst>(BlkUser)){
          auto CallBindF = CI->getCalledFunction();
          auto Name = CallBindF->getName();
          std::string DemangledName;
          if (Name == SPIR_INTRINSIC_GET_BLOCK_INVOKE) {
            assert(CI->getArgOperand(0) == CallBlkBind);
            changed |= lowerGetBlockInvoke(CI, cast<Function>(InvF));
          } else if (Name == SPIR_INTRINSIC_GET_BLOCK_CONTEXT) {
            assert(CI->getArgOperand(0) == CallBlkBind);
            // Handle context_ptr = spir_get_block_context(block)
            lowerGetBlockContext(CI, Ctx);
            changed = true;
          } else if (oclIsBuiltin(Name, &DemangledName)) {
            lowerBlockBuiltin(CI, InvF, Ctx, CtxLen, CtxAlign, DemangledName);
            changed = true;
          } else
            llvm_unreachable("Invalid block user");
        }
      }
      erase(CallBlkBind);
    }
    changed |= eraseUselessFunctions();
    return changed;
  }

  void
  lowerGetBlockContext(CallInst *CallGetBlkCtx, Value *Ctx = nullptr) {
    if (!Ctx)
      getBlockInvokeFuncAndContext(CallGetBlkCtx->getArgOperand(0), nullptr,
          &Ctx);
    CallGetBlkCtx->replaceAllUsesWith(Ctx);
    DEBUG(dbgs() << "  [lowerGetBlockContext] " << *CallGetBlkCtx << " => " <<
        *Ctx << "\n\n");
    erase(CallGetBlkCtx);
  }

  bool
  lowerGetBlockInvoke(CallInst *CallGetBlkInvoke,
      Function *InvokeF = nullptr) {
    bool changed = false;
    for (auto UI = CallGetBlkInvoke->user_begin(),
        UE = CallGetBlkInvoke->user_end();
        UI != UE;) {
      // Handle block_func_ptr = bitcast(spir_get_block_invoke(block))
      auto CallInv = cast<Instruction>(*UI++);
      auto Cast = dyn_cast<BitCastInst>(CallInv);
      if (Cast)
        CallInv = dyn_cast<Instruction>(*CallInv->user_begin());
      DEBUG(dbgs() << "[lowerGetBlockInvoke]  " << *CallInv);
      // Handle ret = block_func_ptr(context_ptr, args)
      auto CI = cast<CallInst>(CallInv);
      auto F = CI->getCalledValue();
      if (InvokeF == nullptr) {
        getBlockInvokeFuncAndContext(CallGetBlkInvoke->getArgOperand(0),
            &InvokeF, nullptr);
        assert(InvokeF);
      }
      assert(F->getType() == InvokeF->getType());
      CI->replaceUsesOfWith(F, InvokeF);
      DEBUG(dbgs() << " => " << *CI << "\n\n");
      erase(Cast);
      changed = true;
    }
    erase(CallGetBlkInvoke);
    return changed;
  }

  void
  lowerBlockBuiltin(CallInst *CI, Function *InvF, Value *Ctx, Value *CtxLen,
      Value *CtxAlign, const std::string& DemangledName) {
    mutateCallInstSPIRV (M, CI, [=](CallInst *CI, std::vector<Value *> &Args) {
      size_t I = 0;
      size_t E = Args.size();
      for (; I != E; ++I) {
        if (isPointerToOpaqueStructType(Args[I]->getType(),
            SPIR_TYPE_NAME_BLOCK_T)) {
          break;
        }
      }
      assert (I < E);
      Args[I] = castToVoidFuncPtr(InvF);
      if (I + 1 == E) {
        Args.push_back(Ctx);
        Args.push_back(CtxLen);
        Args.push_back(CtxAlign);
      } else {
        Args.insert(Args.begin() + I + 1, CtxAlign);
        Args.insert(Args.begin() + I + 1, CtxLen);
        Args.insert(Args.begin() + I + 1, Ctx);
      }
      if (DemangledName == kOCLBuiltinName::EnqueueKernel) {
        // Insert event arguments if there are not.
        if (!isa<IntegerType>(Args[3]->getType())) {
          Args.insert(Args.begin() + 3, getInt32(M, 0));
          Args.insert(Args.begin() + 4, getOCLNullClkEventPtr());
        }
        if (!isOCLClkEventPtrType(Args[5]->getType()))
          Args.insert(Args.begin() + 5, getOCLNullClkEventPtr());
      }
      return getSPIRVFuncName(OCLSPIRVBuiltinMap::map(DemangledName));
    });
  }
  /// Transform return of a block.
  /// The function returning a block is inlined since the context cannot be
  /// passed to another function.
  /// Returns true of module is changed.
  bool
  lowerReturnBlock(ReturnInst *Ret, Value *CallBlkBind, bool &Inlined) {
    auto F = Ret->getParent()->getParent();
    auto changed = false;
    for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
      auto U = *UI++;
      dumpUsers(U);
      auto Inst = dyn_cast<Instruction>(U);
      if (Inst && Inst->use_empty()) {
        erase(Inst);
        changed = true;
        continue;
      }
      auto CI = dyn_cast<CallInst>(U);
      if(!CI || CI->getCalledFunction() != F)
        continue;

      DEBUG(dbgs() << "[lowerReturnBlock] inline " << F->getName() << '\n');
      auto CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
      auto ACT = &getAnalysis<AssumptionCacheTracker>();
      //auto AA = &getAnalysis<AliasAnalysis>();
      //InlineFunctionInfo IFI(CG, M->getDataLayout(), AA, ACT);
      InlineFunctionInfo IFI(CG, ACT);
      InlineFunction(CI, IFI);
      Inlined = true;
    }
    return changed || Inlined;
  }

  void
  getBlockInvokeFuncAndContext(Value *Blk, Function **PInvF, Value **PCtx,
      Value **PCtxLen = nullptr, Value **PCtxAlign = nullptr){
    Function *InvF = nullptr;
    Value *Ctx = nullptr;
    Value *CtxLen = nullptr;
    Value *CtxAlign = nullptr;
    if (auto CallBlkBind = dyn_cast<CallInst>(Blk)) {
      assert(CallBlkBind->getCalledFunction()->getName() ==
          SPIR_INTRINSIC_BLOCK_BIND && "Invalid block");
      InvF = dyn_cast<Function>(
          CallBlkBind->getArgOperand(0)->stripPointerCasts());
      CtxLen = CallBlkBind->getArgOperand(1);
      CtxAlign = CallBlkBind->getArgOperand(2);
      Ctx = CallBlkBind->getArgOperand(3);
    } else if (auto F = dyn_cast<Function>(Blk->stripPointerCasts())) {
      InvF = F;
      Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
    } else if (auto Load = dyn_cast<LoadInst>(Blk)) {
      auto Op = Load->getPointerOperand();
      if (auto GV = dyn_cast<GlobalVariable>(Op)) {
        if (GV->isConstant()) {
          InvF = cast<Function>(GV->getInitializer()->stripPointerCasts());
          Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
        } else {
          llvm_unreachable("load non-constant block?");
        }
      } else {
        llvm_unreachable("Loading block from non global?");
      }
    } else {
      llvm_unreachable("Invalid block");
    }
    DEBUG(dbgs() << "  Block invocation func: " << InvF->getName() << '\n' <<
        "  Block context: " << *Ctx << '\n');
    assert(InvF && Ctx && "Invalid block");
    if (PInvF)
      *PInvF = InvF;
    if (PCtx)
      *PCtx = Ctx;
    if (PCtxLen)
      *PCtxLen = CtxLen;
    if (PCtxAlign)
      *PCtxAlign = CtxAlign;
  }
  void
  erase(Instruction *I) {
    if (!I)
      return;
    if (I->use_empty()) {
      I->dropAllReferences();
      I->eraseFromParent();
    }
    else
      dumpUsers(I);
  }
  void
  erase(ConstantExpr *I) {
    if (!I)
      return;
    if (I->use_empty()) {
      I->dropAllReferences();
      I->destroyConstant();
    } else
      dumpUsers(I);
  }
  void
  erase(Function *F) {
    if (!F)
      return;
    if (!F->use_empty()) {
      dumpUsers(F);
      return;
    }
    F->dropAllReferences();
    auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    CG.removeFunctionFromModule(new CallGraphNode(F));
  }

  llvm::PointerType* getOCLClkEventType() {
    return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T,
        SPIRAS_Global);
  }

  llvm::PointerType* getOCLClkEventPtrType() {
    return PointerType::get(getOCLClkEventType(), SPIRAS_Generic);
  }

  bool isOCLClkEventPtrType(Type *T) {
    if (auto PT = dyn_cast<PointerType>(T))
      return isPointerToOpaqueStructType(
        PT->getElementType(), SPIR_TYPE_NAME_CLK_EVENT_T);
    return false;
  }

  llvm::Constant* getOCLNullClkEventPtr() {
    return Constant::getNullValue(getOCLClkEventPtrType());
  }

  void dumpGetBlockInvokeUsers(StringRef Prompt) {
    DEBUG(dbgs() << Prompt);
    dumpUsers(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
  }
};

char SPIRVLowerOCLBlocks::ID = 0;
}

INITIALIZE_PASS_BEGIN(SPIRVLowerOCLBlocks, "spvblocks",
    "SPIR-V lower OCL blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
//INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
INITIALIZE_PASS_END(SPIRVLowerOCLBlocks, "spvblocks",
    "SPIR-V lower OCL blocks", false, false)

ModulePass *llvm::createSPIRVLowerOCLBlocks() {
  return new SPIRVLowerOCLBlocks();
}

#endif /* OCLLOWERBLOCKS_H_ */