// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/scalar_replacement_pass.h"

#include <algorithm>
#include <queue>
#include <tuple>
#include <utility>

#include "source/enum_string_mapping.h"
#include "source/extensions.h"
#include "source/opt/reflect.h"
#include "source/opt/types.h"
#include "source/util/make_unique.h"

namespace spvtools {
namespace opt {

// Runs scalar replacement on every function of the module. Returns Failure as
// soon as one function fails; otherwise reports whether any function changed.
Pass::Status ScalarReplacementPass::Process() {
  bool anyChange = false;
  for (auto& func : *get_module()) {
    const Status result = ProcessFunction(&func);
    if (result == Status::Failure) return result;
    if (result == Status::SuccessWithChange) anyChange = true;
  }

  return anyChange ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}

// Scalarizes the replaceable variables declared at the top of |function|'s
// entry block.
//
// Candidate variables are collected into a worklist; replacing one variable
// may enqueue its replacements for further scalarization. Returns Failure if
// any replacement fails, SuccessWithChange if at least one variable was
// replaced, and SuccessWithoutChange otherwise.
Pass::Status ScalarReplacementPass::ProcessFunction(Function* function) {
  // A function without any basic block has no entry block to scan;
  // dereferencing begin() below would be invalid.
  if (function->begin() == function->end()) {
    return Status::SuccessWithoutChange;
  }

  std::queue<Instruction*> worklist;
  BasicBlock& entry = *function->begin();
  for (auto iter = entry.begin(); iter != entry.end(); ++iter) {
    // Function storage class OpVariables must appear as the first instructions
    // of the entry block.
    if (iter->opcode() != SpvOpVariable) break;

    Instruction* varInst = &*iter;
    if (CanReplaceVariable(varInst)) {
      worklist.push(varInst);
    }
  }

  Status status = Status::SuccessWithoutChange;
  while (!worklist.empty()) {
    Instruction* varInst = worklist.front();
    worklist.pop();

    if (!ReplaceVariable(varInst, &worklist)) return Status::Failure;
    status = Status::SuccessWithChange;
  }

  return status;
}

// Replaces the composite variable |inst| with one replacement per component
// and rewrites every load, store and access chain that uses it. Replacement
// variables that can themselves be scalarized are pushed onto |worklist|.
// Returns false if an access chain could not be rewritten.
bool ScalarReplacementPass::ReplaceVariable(
    Instruction* inst, std::queue<Instruction*>* worklist) {
  std::vector<Instruction*> replacements;
  CreateReplacementVariables(inst, &replacements);

  // Instructions to delete once all users are rewritten; the original variable
  // is recorded first and therefore killed last.
  std::vector<Instruction*> deadInsts;
  deadInsts.push_back(inst);

  const bool allUsersRewritten = get_def_use_mgr()->WhileEachUser(
      inst, [this, &replacements, &deadInsts](Instruction* user) -> bool {
        // Annotations are not rewritten here.
        if (IsAnnotationInst(user->opcode())) return true;

        switch (user->opcode()) {
          case SpvOpLoad:
            ReplaceWholeLoad(user, replacements);
            deadInsts.push_back(user);
            break;
          case SpvOpStore:
            ReplaceWholeStore(user, replacements);
            deadInsts.push_back(user);
            break;
          case SpvOpAccessChain:
          case SpvOpInBoundsAccessChain:
            if (!ReplaceAccessChain(user, replacements)) return false;
            deadInsts.push_back(user);
            break;
          case SpvOpName:
          case SpvOpMemberName:
            break;
          default:
            assert(false && "Unexpected opcode");
            break;
        }
        return true;
      });
  if (!allUsersRewritten) return false;

  // Clean up some dead code, most recently recorded instruction first.
  for (auto it = deadInsts.rbegin(); it != deadInsts.rend(); ++it) {
    context()->KillInst(*it);
  }

  // Attempt to further scalarize: drop replacement variables nobody uses and
  // queue the ones that are themselves replaceable.
  for (Instruction* replacement : replacements) {
    if (replacement->opcode() != SpvOpVariable) continue;

    if (get_def_use_mgr()->NumUsers(replacement) == 0) {
      context()->KillInst(replacement);
    } else if (CanReplaceVariable(replacement)) {
      worklist->push(replacement);
    }
  }

  return true;
}

// Rewrites |load|, a load of the whole composite, as one load per replacement
// variable followed by an OpCompositeConstruct that reassembles the value.
// Replacements that are not variables are used directly as constituents. All
// uses of |load|'s result are redirected to the new composite; |load| itself
// is left in place for the caller to remove.
void ScalarReplacementPass::ReplaceWholeLoad(
    Instruction* load, const std::vector<Instruction*>& replacements) {
  BasicBlock* const parentBlock = context()->get_instr_block(load);
  std::vector<Instruction*> constituents;
  constituents.reserve(replacements.size());

  BasicBlock::iterator insertPoint(load);
  for (Instruction* replacement : replacements) {
    if (replacement->opcode() == SpvOpVariable) {
      // Create a load of this replacement variable.
      Instruction* elementType = GetStorageType(replacement);
      const uint32_t elementLoadId = TakeNextId();
      auto elementLoad = MakeUnique<Instruction>(
          context(), SpvOpLoad, elementType->result_id(), elementLoadId,
          std::initializer_list<Operand>{
              {SPV_OPERAND_TYPE_ID, {replacement->result_id()}}});
      // In-operand 0 is the pointer; anything after it is a memory access
      // attribute that must be carried over.
      for (uint32_t operandIdx = 1; operandIdx < load->NumInOperands();
           ++operandIdx) {
        Operand attribute(load->GetInOperand(operandIdx));
        elementLoad->AddOperand(std::move(attribute));
      }
      insertPoint = insertPoint.InsertBefore(std::move(elementLoad));
      Instruction* inserted = &*insertPoint;
      get_def_use_mgr()->AnalyzeInstDefUse(inserted);
      context()->set_instr_block(inserted, parentBlock);
      constituents.push_back(inserted);
    } else {
      constituents.push_back(replacement);
    }
  }

  // Construct a new composite immediately before the original load.
  const uint32_t newCompositeId = TakeNextId();
  insertPoint = load;
  auto construct = MakeUnique<Instruction>(
      context(), SpvOpCompositeConstruct, load->type_id(), newCompositeId,
      std::initializer_list<Operand>{});
  for (Instruction* constituent : constituents) {
    Operand constituentId(
        SPV_OPERAND_TYPE_ID,
        std::initializer_list<uint32_t>{constituent->result_id()});
    construct->AddOperand(std::move(constituentId));
  }
  insertPoint = insertPoint.InsertBefore(std::move(construct));
  get_def_use_mgr()->AnalyzeInstDefUse(&*insertPoint);
  context()->set_instr_block(&*insertPoint, parentBlock);
  context()->ReplaceAllUsesWith(load->result_id(), newCompositeId);
}

// Rewrites |store|, a store to the whole composite, as an OpCompositeExtract
// plus OpStore pair for each replacement variable. Replacements that are not
// variables receive no store but still consume a component index. |store|
// itself is left in place for the caller to remove.
void ScalarReplacementPass::ReplaceWholeStore(
    Instruction* store, const std::vector<Instruction*>& replacements) {
  const uint32_t storedValueId = store->GetSingleWordInOperand(1u);
  BasicBlock* const parentBlock = context()->get_instr_block(store);
  BasicBlock::iterator insertPoint(store);

  uint32_t nextComponent = 0;
  for (Instruction* replacement : replacements) {
    const uint32_t component = nextComponent++;
    if (replacement->opcode() != SpvOpVariable) continue;

    // Create the extract of this component from the stored value.
    Instruction* elementType = GetStorageType(replacement);
    const uint32_t extractedId = TakeNextId();
    auto extractInst = MakeUnique<Instruction>(
        context(), SpvOpCompositeExtract, elementType->result_id(),
        extractedId,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_ID, {storedValueId}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {component}}});
    auto inserted = insertPoint.InsertBefore(std::move(extractInst));
    get_def_use_mgr()->AnalyzeInstDefUse(&*inserted);
    context()->set_instr_block(&*inserted, parentBlock);

    // Create the store of the extracted value into the replacement variable.
    auto elementStore = MakeUnique<Instruction>(
        context(), SpvOpStore, 0, 0,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_ID, {replacement->result_id()}},
            {SPV_OPERAND_TYPE_ID, {extractedId}}});
    // In-operands 0 and 1 are the pointer and the data; anything after them
    // is a memory access attribute that must be carried over.
    for (uint32_t operandIdx = 2; operandIdx < store->NumInOperands();
         ++operandIdx) {
      Operand attribute(store->GetInOperand(operandIdx));
      elementStore->AddOperand(std::move(attribute));
    }
    inserted = insertPoint.InsertBefore(std::move(elementStore));
    get_def_use_mgr()->AnalyzeInstDefUse(&*inserted);
    context()->set_instr_block(&*inserted, parentBlock);
  }
}

// Replaces the access chain |chain| with either another access chain (with
// one fewer index) rooted at the selected replacement, or a direct use of the
// replacement when |chain| had a single index.
//
// The first index of |chain| selects the entry of |replacements|. Returns
// false, without modifying anything, if that index is out of bounds. |chain|
// itself is left in place for the caller to remove.
bool ScalarReplacementPass::ReplaceAccessChain(
    Instruction* chain, const std::vector<Instruction*>& replacements) {
  uint32_t indexId = chain->GetSingleWordInOperand(1u);
  const Instruction* index = get_def_use_mgr()->GetDef(indexId);
  size_t indexValue = GetConstantInteger(index);
  // Indexing is 0-based, so an index equal to the number of replacements is
  // already out of bounds and must be rejected before subscripting below.
  if (indexValue >= replacements.size()) {
    // Out of bounds access, this is illegal IR.
    return false;
  }

  const Instruction* var = replacements[indexValue];
  if (chain->NumInOperands() > 2) {
    // Replace input access chain with another access chain.
    BasicBlock::iterator chainIter(chain);
    uint32_t replacementId = TakeNextId();
    std::unique_ptr<Instruction> replacementChain(new Instruction(
        context(), chain->opcode(), chain->type_id(), replacementId,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
    // Add the remaining indexes.
    for (uint32_t i = 2; i < chain->NumInOperands(); ++i) {
      Operand copy(chain->GetInOperand(i));
      replacementChain->AddOperand(std::move(copy));
    }
    auto iter = chainIter.InsertBefore(std::move(replacementChain));
    get_def_use_mgr()->AnalyzeInstDefUse(&*iter);
    context()->set_instr_block(&*iter, context()->get_instr_block(chain));
    context()->ReplaceAllUsesWith(chain->result_id(), replacementId);
  } else {
    // Replace with a use of the variable.
    context()->ReplaceAllUsesWith(chain->result_id(), var->result_id());
  }

  return true;
}

// Appends to |replacements| one entry per component of the composite stored
// in the variable |inst|. For structs and arrays, components reported as
// unused by GetUsedComponents() get a null constant instead of a new variable;
// vectors and matrices always get a variable per element. Finally transfers
// the relevant annotations of |inst| onto the replacements.
void ScalarReplacementPass::CreateReplacementVariables(
    Instruction* inst, std::vector<Instruction*>* replacements) {
  Instruction* storageType = GetStorageType(inst);

  // A null set means the used components could not be determined, in which
  // case every component is treated as used.
  std::unique_ptr<std::unordered_set<uint64_t>> usedComponents =
      GetUsedComponents(inst);

  switch (storageType->opcode()) {
    case SpvOpTypeStruct: {
      uint32_t memberIndex = 0;
      storageType->ForEachInOperand(
          [this, inst, &memberIndex, replacements,
           &usedComponents](uint32_t* memberTypeId) {
            const bool isUsed =
                !usedComponents || usedComponents->count(memberIndex);
            if (isUsed) {
              CreateVariable(*memberTypeId, inst, memberIndex, replacements);
            } else {
              replacements->push_back(CreateNullConstant(*memberTypeId));
            }
            ++memberIndex;
          });
      break;
    }
    case SpvOpTypeArray:
      for (uint32_t element = 0; element != GetArrayLength(storageType);
           ++element) {
        const bool isUsed = !usedComponents || usedComponents->count(element);
        if (isUsed) {
          CreateVariable(storageType->GetSingleWordInOperand(0u), inst, element,
                         replacements);
        } else {
          replacements->push_back(
              CreateNullConstant(storageType->GetSingleWordInOperand(0u)));
        }
      }
      break;

    case SpvOpTypeMatrix:
    case SpvOpTypeVector:
      for (uint32_t element = 0; element != GetNumElements(storageType);
           ++element) {
        CreateVariable(storageType->GetSingleWordInOperand(0u), inst, element,
                       replacements);
      }
      break;

    default:
      assert(false && "Unexpected type.");
      break;
  }

  TransferAnnotations(inst, replacements);
}

void ScalarReplacementPass::TransferAnnotations(
    const Instruction* source, std::vector<Instruction*>* replacements) {
  // Only transfer invariant and restrict decorations on the variable. There are
  // no type or member decorations that are necessary to transfer.
  for (auto inst :
       get_decoration_mgr()->GetDecorationsFor(source->result_id(), false)) {
    assert(inst->opcode() == SpvOpDecorate);
    uint32_t decoration = inst->GetSingleWordInOperand(1u);
    if (decoration == SpvDecorationInvariant ||
        decoration == SpvDecorationRestrict) {
      for (auto var : *replacements) {
        std::unique_ptr<Instruction> annotation(
            new Instruction(context(), SpvOpDecorate, 0, 0,
                            std::initializer_list<Operand>{
                                {SPV_OPERAND_TYPE_ID, {var->result_id()}},
                                {SPV_OPERAND_TYPE_DECORATION, {decoration}}}));
        for (uint32_t i = 2; i < inst->NumInOperands(); ++i) {
          Operand copy(inst->GetInOperand(i));
          annotation->AddOperand(std::move(copy));
        }
        context()->AddAnnotationInst(std::move(annotation));
        get_def_use_mgr()->AnalyzeInstUse(&*--context()->annotation_end());
      }
    }
  }
}

void ScalarReplacementPass::CreateVariable(
    uint32_t typeId, Instruction* varInst, uint32_t index,
    std::vector<Instruction*>* replacements) {
  uint32_t ptrId = GetOrCreatePointerType(typeId);
  uint32_t id = TakeNextId();
  std::unique_ptr<Instruction> variable(new Instruction(
      context(), SpvOpVariable, ptrId, id,
      std::initializer_list<Operand>{
          {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}}}));

  BasicBlock* block = context()->get_instr_block(varInst);
  block->begin().InsertBefore(std::move(variable));
  Instruction* inst = &*block->begin();

  // If varInst was initialized, make sure to initialize its replacement.
  GetOrCreateInitialValue(varInst, index, inst);
  get_def_use_mgr()->AnalyzeInstDefUse(inst);
  context()->set_instr_block(inst, block);

  replacements->push_back(inst);
}

// Returns the id of a Function storage class pointer type whose pointee is
// the type |id|, creating the type if needed. Results are cached in
// |pointee_to_pointer_|.
//
// For a unique pointee type the type manager supplies the pointer id.
// Otherwise the module's types are searched for an existing matching pointer
// that carries no decorations, and a new OpTypePointer is added if none is
// found.
uint32_t ScalarReplacementPass::GetOrCreatePointerType(uint32_t id) {
  auto iter = pointee_to_pointer_.find(id);
  if (iter != pointee_to_pointer_.end()) return iter->second;

  analysis::Type* pointeeTy;
  std::unique_ptr<analysis::Pointer> pointerTy;
  std::tie(pointeeTy, pointerTy) =
      context()->get_type_mgr()->GetTypeAndPointerType(id,
                                                       SpvStorageClassFunction);
  uint32_t ptrId = 0;
  if (pointeeTy->IsUniqueType()) {
    // Non-ambiguous type, just ask the type manager for an id.
    ptrId = context()->get_type_mgr()->GetTypeInstruction(pointerTy.get());
    pointee_to_pointer_[id] = ptrId;
    return ptrId;
  }

  // Ambiguous type. We must perform a linear search to try and find the right
  // type.
  for (const auto& global : context()->types_values()) {
    if (global.opcode() == SpvOpTypePointer &&
        global.GetSingleWordInOperand(0u) == SpvStorageClassFunction &&
        global.GetSingleWordInOperand(1u) == id) {
      // Only reuse a decoration-less pointer of the correct type. The
      // decorations that matter are those on the candidate pointer type
      // itself, not those on the pointee.
      if (get_decoration_mgr()
              ->GetDecorationsFor(global.result_id(), false)
              .empty()) {
        ptrId = global.result_id();
        break;
      }
    }
  }

  if (ptrId != 0) {
    pointee_to_pointer_[id] = ptrId;
    return ptrId;
  }

  // No reusable pointer type exists; add a new one to the module.
  ptrId = TakeNextId();
  context()->AddType(MakeUnique<Instruction>(
      context(), SpvOpTypePointer, 0, ptrId,
      std::initializer_list<Operand>{
          {SPV_OPERAND_TYPE_STORAGE_CLASS, {SpvStorageClassFunction}},
          {SPV_OPERAND_TYPE_ID, {id}}}));
  Instruction* ptr = &*--context()->types_values_end();
  get_def_use_mgr()->AnalyzeInstDefUse(ptr);
  pointee_to_pointer_[id] = ptrId;
  // Register with the type manager if necessary.
  context()->get_type_mgr()->RegisterType(ptrId, *pointerTy);

  return ptrId;
}

// If the variable |source| has an initializer, adds to |newVar| an
// initializer corresponding to component |index| of it:
//  - null initializer: a (cached) OpConstantNull of |newVar|'s storage type;
//  - spec constant initializer: a new OpSpecConstantOp CompositeExtract;
//  - OpConstantComposite initializer: the constituent itself, unless it is
//    OpUndef, in which case |newVar| is left uninitialized.
// Does nothing when |source| has no initializer.
void ScalarReplacementPass::GetOrCreateInitialValue(Instruction* source,
                                                    uint32_t index,
                                                    Instruction* newVar) {
  assert(source->opcode() == SpvOpVariable);
  // In-operand 0 is the storage class; the initializer, if any, is operand 1.
  if (source->NumInOperands() < 2) return;

  const uint32_t sourceInitId = source->GetSingleWordInOperand(1u);
  const uint32_t newVarTypeId = GetStorageType(newVar)->result_id();
  Instruction* sourceInit = get_def_use_mgr()->GetDef(sourceInitId);
  uint32_t componentInitId = 0;
  // TODO(dnovillo): Refactor this with constant propagation.
  if (sourceInit->opcode() == SpvOpConstantNull) {
    // Initialize to appropriate NULL, reusing one created earlier if possible.
    auto cached = type_to_null_.find(newVarTypeId);
    if (cached != type_to_null_.end()) {
      componentInitId = cached->second;
    } else {
      componentInitId = TakeNextId();
      type_to_null_[newVarTypeId] = componentInitId;
      context()->AddGlobalValue(MakeUnique<Instruction>(
          context(), SpvOpConstantNull, newVarTypeId, componentInitId,
          std::initializer_list<Operand>{}));
      Instruction* addedNull = &*--context()->types_values_end();
      get_def_use_mgr()->AnalyzeInstDefUse(addedNull);
    }
  } else if (IsSpecConstantInst(sourceInit->opcode())) {
    // Create a new constant extract.
    componentInitId = TakeNextId();
    context()->AddGlobalValue(MakeUnique<Instruction>(
        context(), SpvOpSpecConstantOp, newVarTypeId, componentInitId,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER, {SpvOpCompositeExtract}},
            {SPV_OPERAND_TYPE_ID, {sourceInit->result_id()}},
            {SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}}}));
    Instruction* addedExtract = &*--context()->types_values_end();
    get_def_use_mgr()->AnalyzeInstDefUse(addedExtract);
  } else if (sourceInit->opcode() == SpvOpConstantComposite) {
    // Get the appropriate index constant.
    componentInitId = sourceInit->GetSingleWordInOperand(index);
    const Instruction* constituent =
        get_def_use_mgr()->GetDef(componentInitId);
    if (constituent->opcode() == SpvOpUndef) {
      // Undef is not a valid initializer for a variable.
      componentInitId = 0;
    }
  } else {
    assert(false);
  }

  if (componentInitId == 0) return;
  newVar->AddOperand({SPV_OPERAND_TYPE_ID, {componentInitId}});
}

// Returns the integer stored in the literal operand |op|, which holds at most
// two 32-bit words with the low-order word first.
size_t ScalarReplacementPass::GetIntegerLiteral(const Operand& op) const {
  assert(op.words.size() <= 2);
  // Widen each word to 64 bits before shifting: shifting a 32-bit value by 32
  // is undefined behaviour and would lose the high-order word.
  uint64_t value = 0;
  for (uint32_t i = 0; i != op.words.size(); ++i) {
    value |= static_cast<uint64_t>(op.words[i]) << (32u * i);
  }
  return static_cast<size_t>(value);
}

size_t ScalarReplacementPass::GetConstantInteger(
    const Instruction* constant) const {
  assert(get_def_use_mgr()->GetDef(constant->type_id())->opcode() ==
         SpvOpTypeInt);
  assert(constant->opcode() == SpvOpConstant ||
         constant->opcode() == SpvOpConstantNull);
  if (constant->opcode() == SpvOpConstantNull) {
    return 0;
  }

  const Operand& op = constant->GetInOperand(0u);
  return GetIntegerLiteral(op);
}

size_t ScalarReplacementPass::GetArrayLength(
    const Instruction* arrayType) const {
  assert(arrayType->opcode() == SpvOpTypeArray);
  const Instruction* length =
      get_def_use_mgr()->GetDef(arrayType->GetSingleWordInOperand(1u));
  return GetConstantInteger(length);
}

// Returns the component count of the OpTypeVector or OpTypeMatrix |type|,
// read from the literal in the type's second in-operand.
size_t ScalarReplacementPass::GetNumElements(const Instruction* type) const {
  assert(type->opcode() == SpvOpTypeVector ||
         type->opcode() == SpvOpTypeMatrix);
  const Operand& op = type->GetInOperand(1u);
  assert(op.words.size() <= 2);
  // Widen each word to 64 bits before shifting: shifting a 32-bit value by 32
  // is undefined behaviour and would lose the high-order word.
  uint64_t count = 0;
  for (uint32_t i = 0; i != op.words.size(); ++i) {
    count |= static_cast<uint64_t>(op.words[i]) << (32u * i);
  }
  return static_cast<size_t>(count);
}

bool ScalarReplacementPass::IsSpecConstant(uint32_t id) const {
  const Instruction* inst = get_def_use_mgr()->GetDef(id);
  assert(inst);
  return spvOpcodeIsSpecConstant(inst->opcode());
}

// Returns the instruction defining the pointee type of the variable |inst|,
// i.e. the type referenced by in-operand 1 of the variable's pointer type.
Instruction* ScalarReplacementPass::GetStorageType(
    const Instruction* inst) const {
  assert(inst->opcode() == SpvOpVariable);

  const Instruction* pointerType = get_def_use_mgr()->GetDef(inst->type_id());
  const uint32_t pointeeTypeId = pointerType->GetSingleWordInOperand(1u);
  return get_def_use_mgr()->GetDef(pointeeTypeId);
}

bool ScalarReplacementPass::CanReplaceVariable(
    const Instruction* varInst) const {
  assert(varInst->opcode() == SpvOpVariable);

  // Can only replace function scope variables.
  if (varInst->GetSingleWordInOperand(0u) != SpvStorageClassFunction)
    return false;

  if (!CheckTypeAnnotations(get_def_use_mgr()->GetDef(varInst->type_id())))
    return false;

  const Instruction* typeInst = GetStorageType(varInst);
  return CheckType(typeInst) && CheckAnnotations(varInst) && CheckUses(varInst);
}

bool ScalarReplacementPass::CheckType(const Instruction* typeInst) const {
  if (!CheckTypeAnnotations(typeInst)) return false;

  switch (typeInst->opcode()) {
    case SpvOpTypeStruct:
      // Don't bother with empty structs or very large structs.
      if (typeInst->NumInOperands() == 0 ||
          IsLargerThanSizeLimit(typeInst->NumInOperands()))
        return false;
      return true;
    case SpvOpTypeArray:
      if (IsSpecConstant(typeInst->GetSingleWordInOperand(1u))) {
        return false;
      }
      if (IsLargerThanSizeLimit(GetArrayLength(typeInst))) {
        return false;
      }
      return true;
      // TODO(alanbaker): Develop some heuristics for when this should be
      // re-enabled.
      //// Specifically including matrix and vector in an attempt to reduce the
      //// number of vector registers required.
      // case SpvOpTypeMatrix:
      // case SpvOpTypeVector:
      //  if (IsLargerThanSizeLimit(GetNumElements(typeInst))) return false;
      //  return true;

    case SpvOpTypeRuntimeArray:
    default:
      return false;
  }
}

// Returns true if every decoration applied to the type |typeInst| belongs to
// the set of decorations that do not prevent scalar replacement.
bool ScalarReplacementPass::CheckTypeAnnotations(
    const Instruction* typeInst) const {
  for (auto decorationInst :
       get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
    // The decoration is in-operand 1 of OpDecorate and in-operand 2 of
    // OpMemberDecorate.
    const bool isPlainDecorate = decorationInst->opcode() == SpvOpDecorate;
    assert(isPlainDecorate || decorationInst->opcode() == SpvOpMemberDecorate);
    const uint32_t decoration =
        decorationInst->GetSingleWordInOperand(isPlainDecorate ? 1u : 2u);

    switch (decoration) {
      case SpvDecorationRowMajor:
      case SpvDecorationColMajor:
      case SpvDecorationArrayStride:
      case SpvDecorationMatrixStride:
      case SpvDecorationCPacked:
      case SpvDecorationInvariant:
      case SpvDecorationRestrict:
      case SpvDecorationOffset:
      case SpvDecorationAlignment:
      case SpvDecorationAlignmentId:
      case SpvDecorationMaxByteOffset:
        continue;
      default:
        return false;
    }
  }

  return true;
}

// Returns true if every decoration applied directly to the variable |varInst|
// belongs to the set of decorations that do not prevent scalar replacement.
bool ScalarReplacementPass::CheckAnnotations(const Instruction* varInst) const {
  for (auto decorationInst :
       get_decoration_mgr()->GetDecorationsFor(varInst->result_id(), false)) {
    assert(decorationInst->opcode() == SpvOpDecorate);
    const uint32_t decoration = decorationInst->GetSingleWordInOperand(1u);
    const bool acceptable = decoration == SpvDecorationInvariant ||
                            decoration == SpvDecorationRestrict ||
                            decoration == SpvDecorationAlignment ||
                            decoration == SpvDecorationAlignmentId ||
                            decoration == SpvDecorationMaxByteOffset;
    if (!acceptable) return false;
  }

  return true;
}

bool ScalarReplacementPass::CheckUses(const Instruction* inst) const {
  VariableStats stats = {0, 0};
  bool ok = CheckUses(inst, &stats);

  // TODO(alanbaker/greg-lunarg): Add some meaningful heuristics about when
  // SRoA is costly, such as when the structure has many (unaccessed?)
  // members.

  return ok;
}

// Returns true if every non-annotation use of |inst| is supported: an access
// chain based on |inst| whose first index is a compile-time constant and
// whose own uses pass CheckUsesRelaxed(), an acceptable load or store, or a
// name instruction. Partial (access chain) and full (load/store) accesses are
// counted in |stats| whether or not they are acceptable.
bool ScalarReplacementPass::CheckUses(const Instruction* inst,
                                      VariableStats* stats) const {
  bool allUsesOk = true;
  get_def_use_mgr()->ForEachUse(
      inst,
      [this, stats, &allUsesOk](const Instruction* user, uint32_t index) {
        // Annotations are checked as a group separately.
        if (IsAnnotationInst(user->opcode())) return;

        switch (user->opcode()) {
          case SpvOpAccessChain:
          case SpvOpInBoundsAccessChain: {
            const bool hasFirstIndex =
                index == 2u && user->NumInOperands() > 1;
            if (!hasFirstIndex) {
              allUsesOk = false;
              break;
            }
            const uint32_t firstIndexId = user->GetSingleWordInOperand(1u);
            const Instruction* firstIndex =
                get_def_use_mgr()->GetDef(firstIndexId);
            // The chain's own uses are only inspected when the first index is
            // a compile-time constant.
            if (!IsCompileTimeConstantInst(firstIndex->opcode()) ||
                !CheckUsesRelaxed(user)) {
              allUsesOk = false;
            }
            stats->num_partial_accesses++;
            break;
          }
          case SpvOpLoad:
            if (!CheckLoad(user, index)) allUsesOk = false;
            stats->num_full_accesses++;
            break;
          case SpvOpStore:
            if (!CheckStore(user, index)) allUsesOk = false;
            stats->num_full_accesses++;
            break;
          case SpvOpName:
          case SpvOpMemberName:
            break;
          default:
            allUsesOk = false;
            break;
        }
      });

  return allUsesOk;
}

// Returns true if every use of |inst| is an acceptable load, an acceptable
// store, or an access chain that uses |inst| at operand index 2 and whose own
// uses recursively satisfy the same conditions. Any other user, annotations
// included, makes the result false.
bool ScalarReplacementPass::CheckUsesRelaxed(const Instruction* inst) const {
  bool allUsesOk = true;
  get_def_use_mgr()->ForEachUse(
      inst, [this, &allUsesOk](const Instruction* user, uint32_t index) {
        bool useOk = false;
        switch (user->opcode()) {
          case SpvOpAccessChain:
          case SpvOpInBoundsAccessChain:
            useOk = index == 2u && CheckUsesRelaxed(user);
            break;
          case SpvOpLoad:
            useOk = CheckLoad(user, index);
            break;
          case SpvOpStore:
            useOk = CheckStore(user, index);
            break;
          default:
            break;
        }
        if (!useOk) allUsesOk = false;
      });

  return allUsesOk;
}

// Returns true if the load |inst| uses the value at operand |index| 2 and its
// memory access operand, if present, does not have the Volatile bit set.
bool ScalarReplacementPass::CheckLoad(const Instruction* inst,
                                      uint32_t index) const {
  if (index != 2u) return false;

  const bool hasMemoryAccess = inst->NumInOperands() >= 2;
  const bool isVolatile =
      hasMemoryAccess &&
      (inst->GetSingleWordInOperand(1u) & SpvMemoryAccessVolatileMask) != 0;
  return !isVolatile;
}

// Returns true if the store |inst| uses the value at operand |index| 0 and its
// memory access operand, if present, does not have the Volatile bit set.
bool ScalarReplacementPass::CheckStore(const Instruction* inst,
                                       uint32_t index) const {
  if (index != 0u) return false;

  const bool hasMemoryAccess = inst->NumInOperands() >= 3;
  const bool isVolatile =
      hasMemoryAccess &&
      (inst->GetSingleWordInOperand(2u) & SpvMemoryAccessVolatileMask) != 0;
  return !isVolatile;
}
// Returns true if |length| exceeds |max_num_elements_|. A limit of 0 disables
// the check, so nothing is ever considered too large.
bool ScalarReplacementPass::IsLargerThanSizeLimit(size_t length) const {
  const bool limitEnabled = max_num_elements_ != 0;
  return limitEnabled && length > max_num_elements_;
}

// Returns the set of component indices of the variable |inst| that are used,
// or a null pointer if that cannot be determined (in which case the caller
// must assume every component is used).
//
// Indices are collected from OpCompositeExtract users of whole loads and from
// constant first indices of access chains. Stores and names contribute
// nothing. Any other user, a load with a non-extract user, or an access chain
// whose first index is not a declared 32- or 64-bit integer constant makes the
// result null.
std::unique_ptr<std::unordered_set<uint64_t>>
ScalarReplacementPass::GetUsedComponents(Instruction* inst) {
  std::unique_ptr<std::unordered_set<uint64_t>> used(
      new std::unordered_set<uint64_t>());

  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();

  // Discards everything collected so far and stops the traversal.
  auto give_up = [&used]() -> bool {
    used.reset(nullptr);
    return false;
  };

  def_use_mgr->WhileEachUser(inst, [&used, &give_up, def_use_mgr,
                                    this](Instruction* user) -> bool {
    switch (user->opcode()) {
      case SpvOpLoad: {
        // Look for extract from the load. The indices only count if every
        // user of the load is an extract.
        std::vector<uint32_t> extracted;
        const bool onlyExtracts = def_use_mgr->WhileEachUser(
            user, [&extracted](Instruction* loadUser) -> bool {
              if (loadUser->opcode() != SpvOpCompositeExtract) return false;
              extracted.push_back(loadUser->GetSingleWordInOperand(1));
              return true;
            });
        if (!onlyExtracts) return give_up();
        used->insert(extracted.begin(), extracted.end());
        return true;
      }
      case SpvOpName:
      case SpvOpMemberName:
      case SpvOpStore:
        // No components are used.
        return true;
      case SpvOpAccessChain:
      case SpvOpInBoundsAccessChain: {
        // Add the first index if it is a constant.
        // TODO: Could be improved by checking if the address is used in a load.
        analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
        const uint32_t first_index_id = user->GetSingleWordInOperand(1);
        const analysis::Constant* first_index =
            const_mgr->FindDeclaredConstant(first_index_id);
        // Without a known constant the chain could reach any element, so all
        // of them are assumed to be used.
        if (!first_index) return give_up();

        const analysis::Integer* first_index_type =
            first_index->type()->AsInteger();
        assert(first_index_type);
        if (first_index_type->width() == 32) {
          used->insert(first_index->GetU32());
          return true;
        }
        if (first_index_type->width() == 64) {
          used->insert(first_index->GetU64());
          return true;
        }
        return give_up();
      }
      default:
        // We do not know what is happening.  Have to assume the worst.
        return give_up();
    }
  });

  return used;
}

// Returns the instruction defining the constant that the constant manager
// produces for type |type_id| from an empty list of literal words, after
// updating def-use information for that instruction.
Instruction* ScalarReplacementPass::CreateNullConstant(uint32_t type_id) {
  analysis::ConstantManager* constants = context()->get_constant_mgr();
  const analysis::Type* null_type = context()->get_type_mgr()->GetType(type_id);

  const analysis::Constant* null_value = constants->GetConstant(null_type, {});
  Instruction* defining_inst =
      constants->GetDefiningInstruction(null_value, type_id);
  context()->UpdateDefUse(defining_inst);
  return defining_inst;
}

}  // namespace opt
}  // namespace spvtools