//===- SPIRVEntry.cpp - Base Class for SPIR-V Entities -----------*- C++ -*-===//
//
//                     The LLVM/SPIRV Translator
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
// Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal with the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimers.
// Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimers in the documentation
// and/or other materials provided with the distribution.
// Neither the names of Advanced Micro Devices, Inc., nor the names of its
// contributors may be used to endorse or promote products derived from this
// Software without specific prior written permission.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
// THE SOFTWARE.
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements base class for SPIR-V entities.
///
//===----------------------------------------------------------------------===//

#include "SPIRVEntry.h"
#include "SPIRVDebug.h"
#include "SPIRVType.h"
#include "SPIRVFunction.h"
#include "SPIRVBasicBlock.h"
#include "SPIRVInstruction.h"
#include "SPIRVDecorate.h"
#include "SPIRVStream.h"

#include <algorithm>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <utility>

using namespace SPIRV;

namespace SPIRV{

/// Generic factory: allocates a default-constructed \p T with `new` and
/// returns it through a SPIRVEntry pointer. The function does not keep or
/// free the object.
template <typename T> SPIRVEntry *create() {
  SPIRVEntry *Created = new T();
  return Created;
}

/// Instantiate the entry class associated with \p OpCode.
///
/// A function-local static table is generated from SPIRVOpCodeEnum.h via the
/// _SPIRV_OP macro; each row pairs an opcode with the matching
/// SPIRV::create<SPIRVxxx> factory. The table is copied into a static map
/// which is then used for the lookup.
///
/// \returns the object produced by the factory, or a null pointer (after
/// emitting a debug message and asserting) when no factory is registered.
SPIRVEntry *SPIRVEntry::create(Op OpCode) {
  using SPIRVFactoryTy = SPIRVEntry *(*)();
  struct TableEntry {
    Op Opn;
    SPIRVFactoryTy Factory;
    // Conversion used by the map's range constructor below.
    operator std::pair<const Op, SPIRVFactoryTy>() {
      return std::make_pair(Opn, Factory);
    }
  };

  static TableEntry Table[] = {
#define _SPIRV_OP(x,...) {Op##x, &SPIRV::create<SPIRV##x>},
#include "SPIRVOpCodeEnum.h"
#undef _SPIRV_OP
  };

  using OpToFactoryMapTy = std::map<Op, SPIRVFactoryTy>;
  static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
                                               std::end(Table));

  const auto Loc = OpToFactoryMap.find(OpCode);
  if (Loc == OpToFactoryMap.end()) {
    SPIRVDBG(spvdbgs() << "No factory for OpCode " << static_cast<unsigned>(OpCode) << '\n';)
    assert(0 && "Not implemented");
    return nullptr;
  }
  return Loc->second();
}

/// Same as create(Op), but the resulting raw pointer is wrapped in a
/// std::unique_ptr before being returned.
std::unique_ptr<SPIRV::SPIRVEntry> SPIRVEntry::create_unique(Op OC) {
  SPIRVEntry *Raw = create(OC);
  return std::unique_ptr<SPIRVEntry>(Raw);
}

/// Allocate a SPIRVExtInst constructed from (\p Set, \p ExtOp) and return it
/// wrapped in a std::unique_ptr.
std::unique_ptr<SPIRV::SPIRVExtInst>
SPIRVEntry::create_unique(SPIRVExtInstSetKind Set, unsigned ExtOp) {
  SPIRVExtInst *Raw = new SPIRVExtInst(Set, ExtOp);
  return std::unique_ptr<SPIRVExtInst>(Raw);
}

/// Forward to the owning module's error log. Module is dereferenced
/// unconditionally, so it must already be set.
SPIRVErrorLog &SPIRVEntry::getErrorLog() const {
  SPIRVErrorLog &Log = Module->getErrorLog();
  return Log;
}

/// Ask the owning module whether \p TheId exists and return its answer.
bool SPIRVEntry::exist(SPIRVId TheId) const {
  const bool Known = Module->exist(TheId);
  return Known;
}

/// Look up \p TheId in the module. When the module reports that it exists,
/// the entry it supplied is returned; otherwise the result of
/// Module->addForward(TheId, nullptr) is returned.
SPIRVEntry *SPIRVEntry::getOrCreate(SPIRVId TheId) const {
  SPIRVEntry *Existing = nullptr;
  if (Module->exist(TheId, &Existing))
    return Existing;
  return Module->addForward(TheId, nullptr);
}

/// Fetch the entry with id \p TheId through get<SPIRVValue>().
SPIRVValue *SPIRVEntry::getValue(SPIRVId TheId) const {
  SPIRVValue *Val = get<SPIRVValue>(TheId);
  return Val;
}

/// Return getType() of the value obtained via get<SPIRVValue>(\p TheId).
/// The looked-up pointer is dereferenced without a null check.
SPIRVType *SPIRVEntry::getValueType(SPIRVId TheId) const {
  SPIRVValue *Val = get<SPIRVValue>(TheId);
  return Val->getType();
}

/// Build a SPIRVEncoder that writes to \p O.
SPIRVEncoder SPIRVEntry::getEncoder(spv_ostream &O) const {
  return SPIRVEncoder(O);
}

/// Build a SPIRVDecoder over \p I that is bound to the owning module
/// (Module is dereferenced, so it must be set).
SPIRVDecoder SPIRVEntry::getDecoder(std::istream &I) {
  return SPIRVDecoder(I, *Module);
}

/// Overwrite the stored word count with \p TheWordCount.
void SPIRVEntry::setWordCount(SPIRVWord TheWordCount) {
  WordCount = TheWordCount;
}

/// Store \p TheName as this entry's name and log the change in debug output.
void SPIRVEntry::setName(const std::string &TheName) {
  Name = TheName;
  SPIRVDBG(spvdbgs() << "Set name for obj " << Id << " " <<
    Name << '\n');
}

/// Attach this entry to \p TheModule.
///
/// Passing the module that is already set is a no-op. Otherwise the entry
/// must not have a module yet (asserted) and \p TheModule becomes its module.
void SPIRVEntry::setModule(SPIRVModule *TheModule) {
  assert(TheModule && "Invalid module");
  if (Module != TheModule) {
    assert(Module == NULL && "Cannot change owner of entry");
    Module = TheModule;
  }
}

/// Base-class encode: always asserts "Not implemented" and writes nothing
/// to \p O.
void SPIRVEntry::encode(spv_ostream &O) const {
  assert(0 && "Not implemented");
}

/// Stream a SPIRVName for this entry to \p O; nothing is written when the
/// stored name is empty.
void SPIRVEntry::encodeName(spv_ostream &O) const {
  if (Name.empty())
    return;
  O << SPIRVName(this, Name);
}

/// Write the complete entry to \p O in three steps, in this order:
/// word count and opcode, then encode(), then encodeChildren().
void SPIRVEntry::encodeAll(spv_ostream &O) const {
  encodeWordCountOpCode(O); // header
  encode(O);                // entry-specific payload
  encodeChildren(O);        // nested entries, if any
}

/// Base-class implementation: intentionally empty, nothing is written to
/// \p O.
void SPIRVEntry::encodeChildren(spv_ostream &O) const {}

/// Write this entry's WordCount and OpCode to \p O.
void
SPIRVEntry::encodeWordCountOpCode(spv_ostream &O) const {
#ifdef _SPIRV_SUPPORT_TEXT_FMT
  // Only compiled when text-format support is enabled: if SPIRVUseTextFormat
  // is set, the word count and the opcode are streamed as two separate items.
  if (SPIRVUseTextFormat) {
    getEncoder(O) << WordCount << OpCode;
    return;
  }
#endif
  // Otherwise both values are combined by mkWord() and streamed as one item.
  getEncoder(O) << mkWord(WordCount, OpCode);
}
// Read words from the SPIR-V binary and populate the members of this
// SPIRVEntry. The word count and the opcode have already been consumed by the
// time this function is called, so the input stream holds only the remaining
// words of the entry.
//
// This base-class version is a placeholder that asserts "Not implemented".
void SPIRVEntry::decode(std::istream &I) {
  assert(0 && "Not implemented");
}

std::vector<SPIRVValue *>
SPIRVEntry::getValues(const std::vector<SPIRVId>& IdVec)const {
  std::vector<SPIRVValue *> ValueVec;
  for (auto i:IdVec)
    ValueVec.push_back(getValue(i));
  return ValueVec;
}

std::vector<SPIRVType *>
SPIRVEntry::getValueTypes(const std::vector<SPIRVId>& IdVec)const {
  std::vector<SPIRVType *> TypeVec;
  for (auto i:IdVec)
    TypeVec.push_back(getValue(i)->getType());
  return TypeVec;
}

/// Collect getId() of every value in \p ValueVec. Each pointer is
/// dereferenced unchecked.
///
/// \returns a vector with one id per input value, in input order.
std::vector<SPIRVId>
SPIRVEntry::getIds(const std::vector<SPIRVValue *> ValueVec) const {
  std::vector<SPIRVId> IdVec;
  // The final size is known up front; reserve once instead of growing
  // repeatedly inside the loop.
  IdVec.reserve(ValueVec.size());
  for (SPIRVValue *Val : ValueVec)
    IdVec.push_back(Val->getId());
  return IdVec;
}

/// Forward the lookup of \p TheId to Module->getEntry().
SPIRVEntry *SPIRVEntry::getEntry(SPIRVId TheId) const {
  SPIRVEntry *Found = Module->getEntry(TheId);
  return Found;
}

/// Check \p TheFCtlMask with isValidFunctionControlMask() and report
/// InvalidFunctionControlMask through SPIRVCK (with an empty message) when
/// the check fails.
void SPIRVEntry::validateFunctionControlMask(SPIRVWord TheFCtlMask) const {
  SPIRVCK(isValidFunctionControlMask(TheFCtlMask),
      InvalidFunctionControlMask, "");
}

void
SPIRVEntry::validateValues(const std::vector<SPIRVId> &Ids)const {
  for (auto I:Ids)
    getValue(I)->validate();
}

/// Assert that neither \p TheSet nor \p Index equals SPIRVWORD_MAX.
/// The check exists only inside assert(), so it vanishes when asserts are
/// compiled out; the void casts keep the parameters "used" in that case.
void SPIRVEntry::validateBuiltin(SPIRVWord TheSet, SPIRVWord Index) const {
  (void)TheSet;
  (void)Index;
  assert(TheSet != SPIRVWORD_MAX && Index != SPIRVWORD_MAX &&
      "Invalid builtin");
}

/// Record \p Dec on this entry, keyed by its decoration kind, and also
/// register it with the module.
void SPIRVEntry::addDecorate(const SPIRVDecorate *Dec) {
  const auto DecKind = Dec->getDecorateKind();
  Decorates.insert(std::make_pair(DecKind, Dec));
  Module->addDecorate(Dec);
  SPIRVDBG(spvdbgs() << "[addDecorate] " << *Dec << '\n';)
}

/// Allocate a SPIRVDecorate of \p Kind targeting this entry and add it via
/// addDecorate(const SPIRVDecorate *).
void SPIRVEntry::addDecorate(Decoration Kind) {
  const SPIRVDecorate *NewDec = new SPIRVDecorate(Kind, this);
  addDecorate(NewDec);
}

/// Allocate a SPIRVDecorate of \p Kind carrying \p Literal, targeting this
/// entry, and add it via addDecorate(const SPIRVDecorate *).
void SPIRVEntry::addDecorate(Decoration Kind, SPIRVWord Literal) {
  const SPIRVDecorate *NewDec = new SPIRVDecorate(Kind, this, Literal);
  addDecorate(NewDec);
}

/// Erase the decorations stored under key \p Dec from this entry's
/// decoration map. Only the local map is touched here.
void SPIRVEntry::eraseDecorate(Decoration Dec) {
  Decorates.erase(Dec);
}

/// Replace this entry's decoration map by move-assigning the one held by
/// \p E, then log the operation in debug output.
void SPIRVEntry::takeDecorates(SPIRVEntry *E) {
  Decorates = std::move(E->Decorates);
  SPIRVDBG(spvdbgs() << "[takeDecorates] " << Id << '\n';)
}

/// Store \p L as this entry's line record and set its target id to this
/// entry's Id. \p L is dereferenced unconditionally.
void SPIRVEntry::setLine(SPIRVLine *L) {
  Line = L;
  L->setTargetId(Id);
  SPIRVDBG(spvdbgs() << "[setLine] " << *L << '\n';)
}

/// Copy the line pointer from \p E. When it is non-null, its target id is
/// set to this entry's Id and \p E's pointer is cleared.
void SPIRVEntry::takeLine(SPIRVEntry *E) {
  Line = E->Line;
  if (Line != nullptr) {
    Line->setTargetId(Id);
    E->Line = nullptr;
  }
}

/// Record the member decoration \p Dec under the key Dec->getPair() and
/// register it with the module.
///
/// Asserts that canHaveMemberDecorates() holds and that no decoration is
/// stored under the same key yet.
void SPIRVEntry::addMemberDecorate(const SPIRVMemberDecorate *Dec) {
  assert(canHaveMemberDecorates() && MemberDecorates.find(Dec->getPair()) ==
      MemberDecorates.end());
  MemberDecorates[Dec->getPair()] = Dec;
  Module->addDecorate(Dec);
  SPIRVDBG(spvdbgs() << "[addMemberDecorate] " << *Dec << '\n';)
}

/// Allocate a SPIRVMemberDecorate of \p Kind for member \p MemberNumber of
/// this entry and add it via addMemberDecorate(const SPIRVMemberDecorate *).
void SPIRVEntry::addMemberDecorate(SPIRVWord MemberNumber, Decoration Kind) {
  const SPIRVMemberDecorate *NewDec =
      new SPIRVMemberDecorate(Kind, MemberNumber, this);
  addMemberDecorate(NewDec);
}

/// Allocate a SPIRVMemberDecorate of \p Kind carrying \p Literal for member
/// \p MemberNumber of this entry and add it via
/// addMemberDecorate(const SPIRVMemberDecorate *).
void SPIRVEntry::addMemberDecorate(SPIRVWord MemberNumber, Decoration Kind,
                                   SPIRVWord Literal) {
  const SPIRVMemberDecorate *NewDec =
      new SPIRVMemberDecorate(Kind, MemberNumber, this, Literal);
  addMemberDecorate(NewDec);
}

/// Erase the member decoration stored under the key
/// (\p MemberNumber, \p Dec) from this entry's member-decoration map.
void SPIRVEntry::eraseMemberDecorate(SPIRVWord MemberNumber, Decoration Dec) {
  MemberDecorates.erase(std::make_pair(MemberNumber, Dec));
}

/// Replace this entry's member-decoration map by move-assigning the one held
/// by \p E, then log the operation in debug output.
void SPIRVEntry::takeMemberDecorates(SPIRVEntry *E) {
  MemberDecorates = std::move(E->MemberDecorates);
  SPIRVDBG(spvdbgs() << "[takeMemberDecorates] " << Id << '\n';)
}

/// Transfer the annotations gathered on the forward entry \p E to this
/// entry: its name (set through the module), decorations, member
/// decorations and line record. When this entry's opcode is OpFunction, the
/// execution modes are taken over as well.
void SPIRVEntry::takeAnnotations(SPIRVForward *E) {
  Module->setName(this, E->getName());
  takeDecorates(E);
  takeMemberDecorates(E);
  takeLine(E);
  if (OpCode != OpFunction)
    return;
  static_cast<SPIRVFunction *>(this)->takeExecutionModes(E);
}

// Report whether this entry carries a decoration of the given Kind. When it
// does and Result is non-null, the literal at position Index of the
// decoration returned by the map lookup is written to *Result.
bool SPIRVEntry::hasDecorate(Decoration Kind, size_t Index,
                             SPIRVWord *Result) const {
  const DecorateMapType::const_iterator Hit = Decorates.find(Kind);
  const bool Present = Hit != Decorates.end();
  if (Present && Result)
    *Result = Hit->second->getLiteral(Index);
  return Present;
}

// Collect, as a set, the literal at position Index of every decoration of
// the given Kind attached to this entry. Index is asserted to be within each
// decoration's literal count.
std::set<SPIRVWord> SPIRVEntry::getDecorate(Decoration Kind,
                                            size_t Index) const {
  std::set<SPIRVWord> Literals;
  const auto Matches = Decorates.equal_range(Kind);
  for (auto I = Matches.first; I != Matches.second; ++I) {
    assert(Index < I->second->getLiteralCount() && "Invalid index");
    Literals.insert(I->second->getLiteral(Index));
  }
  return Literals;
}

/// True exactly when this entry's opcode is OpFunction or OpVariable.
bool SPIRVEntry::hasLinkageType() const {
  if (OpCode == OpFunction)
    return true;
  return OpCode == OpVariable;
}

/// Stream every decoration stored on this entry to \p O, in the iteration
/// order of the decoration map.
void SPIRVEntry::encodeDecorate(spv_ostream &O) const {
  for (auto &KindAndDec : Decorates) {
    O << *KindAndDec.second;
  }
}

/// Return the linkage type taken from the DecorationLinkageAttributes
/// decoration found on this entry, or LinkageTypeInternal when there is no
/// such decoration. Asserts hasLinkageType().
SPIRVLinkageTypeKind SPIRVEntry::getLinkageType() const {
  assert(hasLinkageType());
  const auto Hit = Decorates.find(DecorationLinkageAttributes);
  if (Hit != Decorates.end()) {
    const auto *LinkAttr =
        static_cast<const SPIRVDecorateLinkageAttr *>(Hit->second);
    return LinkAttr->getLinkageType();
  }
  return LinkageTypeInternal;
}

/// Attach a new SPIRVDecorateLinkageAttr built from this entry, its current
/// Name and \p LT. Asserts that \p LT is valid and that hasLinkageType()
/// holds.
void SPIRVEntry::setLinkageType(SPIRVLinkageTypeKind LT) {
  assert(isValid(LT));
  assert(hasLinkageType());
  const SPIRVDecorate *LinkAttr = new SPIRVDecorateLinkageAttr(this, Name, LT);
  addDecorate(LinkAttr);
}

/// Pass getRequiredSPIRVVersion() to the module's setMinSPIRVVersion().
/// Does nothing while the entry has no module.
void SPIRVEntry::updateModuleVersion() const {
  if (Module)
    Module->setMinSPIRVVersion(getRequiredSPIRVVersion());
}

/// Stream insertion for entries: validates \p E, encodes it completely with
/// encodeAll() and then streams a SPIRVNL() object.
spv_ostream &operator<<(spv_ostream &O, const SPIRVEntry &E) {
  E.validate();
  E.encodeAll(O);
  O << SPIRVNL();
  return O;
}

/// Stream extraction for entries: delegates to E.decode(I) and returns the
/// stream.
std::istream &operator>>(std::istream &I, SPIRVEntry &E) {
  E.decode(I);
  return I;
}

/// Construct an entry-point annotation whose target is the SPIRVFunction
/// with id \p TheId in \p TheModule. The word count passed to the base class
/// is the size of \p TheName in words plus 3.
SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
                                 SPIRVExecutionModelKind TheExecModel,
                                 SPIRVId TheId, const std::string &TheName)
    : SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
                      getSizeInWords(TheName) + 3),
      ExecModel(TheExecModel), Name(TheName) {}

/// Write the operands in the order: execution model, target, name.
void SPIRVEntryPoint::encode(spv_ostream &O) const {
  getEncoder(O) << ExecModel << Target << Name;
}

void
SPIRVEntryPoint::decode(std::istream &I) {
  getDecoder(I) >> ExecModel >> Target >> Name;
  Module->setName(getOrCreateTarget(), Name);
  Module->addEntryPoint(ExecModel, Target);
}

/// Write the operands in the order: target, execution mode, word literals.
void SPIRVExecutionMode::encode(spv_ostream &O) const {
  getEncoder(O) << Target << ExecMode << WordLiterals;
}

/// Read the target and execution mode, size WordLiterals according to the
/// mode, read the literals, and finally register this execution mode on the
/// target obtained from getOrCreateTarget().
void SPIRVExecutionMode::decode(std::istream &I) {
  getDecoder(I) >> Target >> ExecMode;

  // WordLiterals is resized only for the modes listed here; for every other
  // mode it keeps its current size.
  switch (ExecMode) {
  case ExecutionModeInvocations:
  case ExecutionModeOutputVertices:
  case ExecutionModeVecTypeHint:
    WordLiterals.resize(1);
    break;
  case ExecutionModeLocalSize:
  case ExecutionModeLocalSizeHint:
    WordLiterals.resize(3);
    break;
  default:
    // Do nothing. Keep this to avoid VS2013 warning.
    break;
  }

  getDecoder(I) >> WordLiterals;
  getOrCreateTarget()->addExecutionMode(this);
}

/// Return the entry for Target as a SPIRVForward.
///
/// If the module already knows the id, that entry is asserted to have opcode
/// OpForward and is returned; otherwise the result of
/// Module->addForward(Target, nullptr) is returned.
SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
  SPIRVEntry *Entry = nullptr;
  const bool Found = Module->exist(Target, &Entry);
  assert((!Found || Entry->getOpCode() == OpForward) &&
      "Annotations only allowed on forward");
  if (Found)
    return static_cast<SPIRVForward *>(Entry);
  Entry = Module->addForward(Target, nullptr);
  return static_cast<SPIRVForward *>(Entry);
}

/// Construct a name annotation for \p TheTarget holding \p TheStr. The word
/// count passed to the base class is the string's size in words plus 2.
SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
    : SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}

/// Write the target followed by the name string.
void SPIRVName::encode(spv_ostream &O) const {
  getEncoder(O) << Target << Str;
}

void
SPIRVName::decode(std::istream &I) {
  getDecoder(I) >> Target >> Str;
  Module->setName(getOrCreateTarget(), Str);
}

/// Assert that WordCount equals the size of Str in words plus 2, matching
/// the value computed in the constructor.
void SPIRVName::validate() const {
  assert(WordCount == getSizeInWords(Str) + 2 && "Incorrect word count");
}

// NOTE(review): the _SPIRV_IMP_ENCDEC* macros are defined outside this file.
// Presumably they generate the encode()/decode() members of the named class,
// streaming the listed fields in the given order -- confirm against the macro
// definition.
_SPIRV_IMP_ENCDEC2(SPIRVString, Id, Str)
_SPIRV_IMP_ENCDEC3(SPIRVMemberName, Target, MemberNumber, Str)

/// Write the operands in the order: target, file name id, line, column.
void SPIRVLine::encode(spv_ostream &O) const {
  getEncoder(O) << Target << FileName << Line << Column;
}

/// Read target, file name id, line and column (in that order) and register
/// the line information with the module via Module->addLine(). The file name
/// id is resolved with get<SPIRVString>().
void SPIRVLine::decode(std::istream &I) {
  getDecoder(I) >> Target >> FileName >> Line >> Column;
  Module->addLine(getOrCreateTarget(), get<SPIRVString>(FileName), Line,
                  Column);
}

/// Assertion-only sanity checks for a line record.
void SPIRVLine::validate() const {
  // Fixed shape of the instruction.
  assert(OpCode == OpLine);
  assert(WordCount == 5);
  // The target must resolve, and FileName must refer to an OpString entry.
  assert(get<SPIRVEntry>(Target));
  assert(get<SPIRVEntry>(FileName)->getOpCode() == OpString);
  // SPIRVWORD_MAX is rejected for both the line and the column.
  assert(Line != SPIRVWORD_MAX);
  assert(Column != SPIRVWORD_MAX);
}

/// Assertion-only sanity checks for a member-name annotation.
void SPIRVMemberName::validate() const {
  assert(OpCode == OpMemberName);
  // Word count is the string's size in words plus the fixed part (FixedWC).
  assert(WordCount == getSizeInWords(Str) + FixedWC);
  // The target must be a struct type and the member index must be in range.
  assert(get<SPIRVEntry>(Target)->getOpCode() == OpTypeStruct);
  assert(MemberNumber < get<SPIRVTypeStruct>(Target)->getStructMemberCount());
}

/// Construct an extended-instruction-set import with id \p TheId and set
/// name \p TheStr in \p TheModule. The word count is 2 plus the string's
/// size in words; validate() is called before the constructor returns.
SPIRVExtInstImport::SPIRVExtInstImport(SPIRVModule *TheModule, SPIRVId TheId,
                                       const std::string &TheStr)
    : SPIRVEntry(TheModule, 2 + getSizeInWords(TheStr), OC, TheId),
      Str(TheStr) {
  validate();
}

/// Write the id followed by the set name string.
void SPIRVExtInstImport::encode(spv_ostream &O) const {
  getEncoder(O) << Id << Str;
}

/// Read the id and the set name string, then pass both to
/// Module->importBuiltinSetWithId().
void SPIRVExtInstImport::decode(std::istream &I) {
  getDecoder(I) >> Id >> Str;
  Module->importBuiltinSetWithId(Str, Id);
}

/// Run the base-class validation, then assert that the set name is not
/// empty.
void SPIRVExtInstImport::validate() const {
  SPIRVEntry::validate();
  assert(!Str.empty() && "Invalid builtin set");
}

/// Write the module's addressing model followed by its memory model. Both
/// values are read from the module at encode time.
void SPIRVMemoryModel::encode(spv_ostream &O) const {
  getEncoder(O) << Module->getAddressingModel() << Module->getMemoryModel();
}

void
SPIRVMemoryModel::decode(std::istream &I) {
  SPIRVAddressingModelKind AddrModel;
  SPIRVMemoryModelKind MemModel;
  getDecoder(I) >> AddrModel >> MemModel;
  Module->setAddressingModel(AddrModel);
  Module->setMemoryModel(MemModel);
}

/// Check that the module's addressing model and memory model are valid,
/// reporting InvalidAddressingModel / InvalidMemoryModel through SPIRVCK
/// otherwise.
void
SPIRVMemoryModel::validate() const {
  auto AM = Module->getAddressingModel();
  auto MM = Module->getMemoryModel();
  // The numeric value is formatted with std::to_string. Writing
  // `"Actual is " + AM` directly would add the enum value to the string
  // literal instead of producing the intended text.
  SPIRVCK(isValid(AM), InvalidAddressingModel,
      "Actual is " + std::to_string(static_cast<unsigned>(AM)));
  SPIRVCK(isValid(MM), InvalidMemoryModel,
      "Actual is " + std::to_string(static_cast<unsigned>(MM)));
}

/// Write the module's source language followed by its version. The version
/// local starts as SPIRVWORD_MAX and is handed to getSourceLanguage() by
/// address.
void SPIRVSource::encode(spv_ostream &O) const {
  SPIRVWord SourceVer = SPIRVWORD_MAX;
  auto SourceLang = Module->getSourceLanguage(&SourceVer);
  getEncoder(O) << SourceLang << SourceVer;
}

void
SPIRVSource::decode(std::istream &I) {
  SourceLanguage Lang = SourceLanguageUnknown;
  SPIRVWord Ver = SPIRVWORD_MAX;
  getDecoder(I) >> Lang >> Ver;
  Module->setSourceLanguage(Lang, Ver);
}

/// Construct a source-extension entry for module \p M holding \p SS. The
/// word count is 1 plus the string's size in words.
SPIRVSourceExtension::SPIRVSourceExtension(SPIRVModule *M,
                                           const std::string &SS)
    : SPIRVEntryNoId(M, 1 + getSizeInWords(SS)), S(SS) {}

/// Write the extension string.
void SPIRVSourceExtension::encode(spv_ostream &O) const {
  getEncoder(O) << S;
}

/// Read the extension string and insert it into the container returned by
/// Module->getSourceExtension().
void SPIRVSourceExtension::decode(std::istream &I) {
  getDecoder(I) >> S;
  Module->getSourceExtension().insert(S);
}

/// Construct an extension entry for module \p M holding \p SS. The word
/// count is 1 plus the string's size in words.
SPIRVExtension::SPIRVExtension(SPIRVModule *M, const std::string &SS)
    : SPIRVEntryNoId(M, 1 + getSizeInWords(SS)), S(SS) {}

/// Write the extension string.
void SPIRVExtension::encode(spv_ostream &O) const {
  getEncoder(O) << S;
}

/// Read the extension string and insert it into the container returned by
/// Module->getExtension().
void SPIRVExtension::decode(std::istream &I) {
  getDecoder(I) >> S;
  Module->getExtension().insert(S);
}

/// Construct a capability entry of kind \p K for module \p M with a word
/// count of 2, then call updateModuleVersion().
SPIRVCapability::SPIRVCapability(SPIRVModule *M, SPIRVCapabilityKind K)
    : SPIRVEntryNoId(M, 2), Kind(K) {
  updateModuleVersion();
}

/// Write the capability kind.
void SPIRVCapability::encode(spv_ostream &O) const {
  getEncoder(O) << Kind;
}

/// Read the capability kind and register it via Module->addCapability().
void SPIRVCapability::decode(std::istream &I) {
  getDecoder(I) >> Kind;
  Module->addCapability(Kind);
}

} // namespace SPIRV