// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/scalar_analysis.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <string>
#include <utility>

#include "source/opt/ir_context.h"

// Transforms a given scalar operation instruction into a DAG representation.
//
// 1. Take an instruction and traverse its operands until we reach a
// constant node or an instruction whose value we do not know how to compute,
// such as a load.
//
// 2. Create a new node for each instruction traversed and build the nodes for
// the in operands of that instruction as well.
//
// 3. Add the operand nodes as children of the first and hash the node. Use the
// hash to see if the node is already in the cache. We ensure the children are
// always in sorted order so that two nodes with the same children but inserted
// in a different order have the same hash and so that the overloaded operator==
// will return true. If the node is already in the cache return the cached
// version instead.
//
// 4. The created DAG can then be simplified by
// ScalarAnalysis::SimplifyExpression, implemented in
// scalar_analysis_simplification.cpp. See that file for further information on
// the simplification process.
//

namespace spvtools {
namespace opt {

// Definition of the static SENode node counter; it starts at zero.
// NOTE(review): presumably used to hand out a unique id to each created node
// -- confirm against the SENode constructor in the header.
uint32_t SENode::NumberOfNodes = 0;

ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
    : context_(context), pretend_equal_{} {
  // Create the CantCompute node up front and remember the cached instance so
  // that CreateCantComputeNode can hand it out without allocating again.
  std::unique_ptr<SENode> cant_compute{new SECantCompute(this)};
  cached_cant_compute_ = GetCachedOrAdd(std::move(cant_compute));
}

// Returns a node representing -|operand|. An uncomputable operand yields the
// CantCompute node and a constant operand is folded into a new constant.
SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
  // Negating something we cannot compute is still something we cannot compute.
  if (operand->IsCantCompute()) return CreateCantComputeNode();

  if (operand->GetType() != SENode::Constant) {
    // General case: wrap the operand in a unary negation node.
    std::unique_ptr<SENode> negated{new SENegative(this)};
    negated->AddChild(operand);
    return GetCachedOrAdd(std::move(negated));
  }

  // Constant operand: fold the negation directly.
  const auto value = operand->AsSEConstantNode()->FoldToSingleValue();
  return CreateConstant(-value);
}

// Returns the cached constant node holding |integer|, creating it on first
// use.
SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
  std::unique_ptr<SENode> constant_node{new SEConstantNode(this, integer)};
  return GetCachedOrAdd(std::move(constant_node));
}

// Builds (or fetches from the cache) the recurrent expression for |loop| with
// the given |offset| and |coefficient|. Returns the CantCompute node if either
// operand cannot be computed.
SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
    const Loop* loop, SENode* offset, SENode* coefficient) {
  assert(loop && "Recurrent add expressions must have a valid loop.");

  // A recurrence built on an uncomputable operand is itself uncomputable.
  const bool computable =
      !offset->IsCantCompute() && !coefficient->IsCantCompute();
  if (!computable) return CreateCantComputeNode();

  // If a non-null substitute is registered for |loop| in pretend_equal_, the
  // node is attached to that loop instead of |loop| itself.
  const Loop* const substitute = pretend_equal_[loop];
  const Loop* const loop_to_use = substitute ? substitute : loop;

  std::unique_ptr<SERecurrentNode> recurrence{
      new SERecurrentNode(this, loop_to_use)};
  recurrence->AddOffset(offset);
  recurrence->AddCoefficient(coefficient);

  return GetCachedOrAdd(std::move(recurrence));
}

// Analyzes both operands of the OpIMul instruction |multiply| and returns the
// node representing their product.
SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
    const Instruction* multiply) {
  assert(multiply->opcode() == SpvOp::SpvOpIMul &&
         "Multiply node did not come from a multiply instruction");
  analysis::DefUseManager* def_use = context_->get_def_use_mgr();

  // The left operand is analyzed before the right one.
  Instruction* lhs_def = def_use->GetDef(multiply->GetSingleWordInOperand(0));
  SENode* lhs = AnalyzeInstruction(lhs_def);

  Instruction* rhs_def = def_use->GetDef(multiply->GetSingleWordInOperand(1));
  SENode* rhs = AnalyzeInstruction(rhs_def);

  return CreateMultiplyNode(lhs, rhs);
}

// Returns a node representing |operand_1| * |operand_2|. Uncomputable operands
// yield the CantCompute node and two constants are folded into one constant.
SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
                                                    SENode* operand_2) {
  // An uncomputable factor makes the whole product uncomputable.
  if (operand_1->IsCantCompute()) return CreateCantComputeNode();
  if (operand_2->IsCantCompute()) return CreateCantComputeNode();

  // Two constants fold straight into a single constant node.
  const bool both_constant = operand_1->GetType() == SENode::Constant &&
                             operand_2->GetType() == SENode::Constant;
  if (both_constant) {
    const auto lhs = operand_1->AsSEConstantNode()->FoldToSingleValue();
    const auto rhs = operand_2->AsSEConstantNode()->FoldToSingleValue();
    return CreateConstant(lhs * rhs);
  }

  std::unique_ptr<SENode> product{new SEMultiplyNode(this)};
  product->AddChild(operand_1);
  product->AddChild(operand_2);
  return GetCachedOrAdd(std::move(product));
}

// Returns a node representing |operand_1| - |operand_2|. Two constants are
// folded; otherwise the subtraction is expressed as operand_1 + (-operand_2).
SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
                                                   SENode* operand_2) {
  const bool both_constant = operand_1->GetType() == SENode::Constant &&
                             operand_2->GetType() == SENode::Constant;
  if (!both_constant) {
    SENode* negated_rhs = CreateNegation(operand_2);
    return CreateAddNode(operand_1, negated_rhs);
  }

  const auto lhs = operand_1->AsSEConstantNode()->FoldToSingleValue();
  const auto rhs = operand_2->AsSEConstantNode()->FoldToSingleValue();
  return CreateConstant(lhs - rhs);
}

// Returns a node representing |operand_1| + |operand_2|. Two constants are
// folded into one constant and uncomputable operands yield the CantCompute
// node.
SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
                                               SENode* operand_2) {
  // Fold when both operands are constants.
  const bool both_constant = operand_1->GetType() == SENode::Constant &&
                             operand_2->GetType() == SENode::Constant;
  if (both_constant) {
    const auto lhs = operand_1->AsSEConstantNode()->FoldToSingleValue();
    const auto rhs = operand_2->AsSEConstantNode()->FoldToSingleValue();
    return CreateConstant(lhs + rhs);
  }

  // An uncomputable term makes the whole sum uncomputable.
  if (operand_1->IsCantCompute()) return CreateCantComputeNode();
  if (operand_2->IsCantCompute()) return CreateCantComputeNode();

  std::unique_ptr<SENode> sum{new SEAddNode(this)};
  sum->AddChild(operand_1);
  sum->AddChild(operand_2);
  return GetCachedOrAdd(std::move(sum));
}

// Converts |inst| into its scalar-evolution node, dispatching on the opcode.
// Opcodes without a dedicated handler become value-unknown nodes.
SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
  // Instructions already registered in recurrent_node_map_ (phis that were
  // analyzed, or are currently being analyzed) reuse the recorded node.
  const auto recorded = recurrent_node_map_.find(inst);
  if (recorded != recurrent_node_map_.end()) return recorded->second;

  switch (inst->opcode()) {
    case SpvOp::SpvOpPhi:
      return AnalyzePhiInstruction(inst);
    case SpvOp::SpvOpConstant:
    case SpvOp::SpvOpConstantNull:
      return AnalyzeConstant(inst);
    case SpvOp::SpvOpISub:
    case SpvOp::SpvOpIAdd:
      return AnalyzeAddOp(inst);
    case SpvOp::SpvOpIMul:
      return AnalyzeMultiplyOp(inst);
    default:
      return CreateValueUnknownNode(inst);
  }
}

// Converts an OpConstant / OpConstantNull instruction into a constant node.
// Returns the CantCompute node when the constant is not declared in the
// constant manager, is not an integer, or is not exactly one word wide.
SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
  // A null constant is modelled as the integer zero.
  if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0);

  assert(inst->opcode() == SpvOp::SpvOpConstant);
  assert(inst->NumInOperands() == 1);

  // Ask the constant manager for the constant declared by this result id.
  const analysis::Constant* declared =
      context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
  if (!declared) return CreateCantComputeNode();

  // Only single-word integer constants are handled; wider ones (such as
  // 64 bit integers) are rejected.
  const analysis::IntConstant* as_int = declared->AsIntConstant();
  if (!as_int) return CreateCantComputeNode();
  if (as_int->words().size() != 1) return CreateCantComputeNode();

  // Read the value through the accessor matching the signedness of the type.
  if (as_int->type()->AsInteger()->IsSigned()) {
    const int64_t signed_value = as_int->GetS32BitValue();
    return CreateConstant(signed_value);
  }
  const int64_t unsigned_value = as_int->GetU32BitValue();
  return CreateConstant(unsigned_value);
}

// Handles both OpIAdd and OpISub. An addition becomes op1+op2; a subtraction
// becomes op1+(-op2) by wrapping the second operand in a negation node.
SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
  const bool is_subtraction = inst->opcode() == SpvOp::SpvOpISub;
  assert((inst->opcode() == SpvOp::SpvOpIAdd || is_subtraction) &&
         "Add node must be created from a OpIAdd or OpISub instruction");

  analysis::DefUseManager* def_use = context_->get_def_use_mgr();

  // The left operand is analyzed before the right one.
  Instruction* lhs_def = def_use->GetDef(inst->GetSingleWordInOperand(0));
  SENode* lhs = AnalyzeInstruction(lhs_def);

  Instruction* rhs_def = def_use->GetDef(inst->GetSingleWordInOperand(1));
  SENode* rhs = AnalyzeInstruction(rhs_def);

  return CreateAddNode(lhs, is_subtraction ? CreateNegation(rhs) : rhs);
}

// Tries to express the loop-header |phi| as a recurrent expression whose
// offset is the value coming from the loop preheader and whose coefficient is
// the loop-invariant step added to the phi on the latch edge. Returns the
// CantCompute node whenever the phi does not match that shape. Except for the
// early exits on the operand count and on a missing loop descriptor, the
// result is also recorded in recurrent_node_map_.
SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
  // The phi should only have two incoming value pairs.
  if (phi->NumInOperands() != 4) {
    return CreateCantComputeNode();
  }

  analysis::DefUseManager* def_use = context_->get_def_use_mgr();

  // Get the basic block this instruction belongs to.
  BasicBlock* basic_block =
      context_->get_instr_block(const_cast<Instruction*>(phi));

  // And then the function that the basic block belongs to.
  Function* function = basic_block->GetParent();

  // Use the function to get the loop descriptor.
  LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);

  // We only handle phis in loops at the moment.
  if (!loop_descriptor) return CreateCantComputeNode();

  // Get the innermost loop which this block belongs to.
  Loop* loop = (*loop_descriptor)[basic_block->id()];

  // If the loop doesn't exist, doesn't have a preheader or latch block, or the
  // phi is not in the loop header, exit out.
  if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
      loop->GetHeaderBlock() != basic_block)
    return recurrent_node_map_[phi] = CreateCantComputeNode();

  // If a non-null substitute is registered for |loop| in pretend_equal_, the
  // node is attached to that loop instead.
  const Loop* loop_to_use = nullptr;
  if (pretend_equal_[loop]) {
    loop_to_use = pretend_equal_[loop];
  } else {
    loop_to_use = loop;
  }
  std::unique_ptr<SERecurrentNode> phi_node{
      new SERecurrentNode(this, loop_to_use)};

  // We add the node to this map to allow it to be returned before the node is
  // fully built. This is needed as the subsequent call to AnalyzeInstruction
  // could lead back to this |phi| instruction so we return the pointer
  // immediately in AnalyzeInstruction to break the recursion.
  recurrent_node_map_[phi] = phi_node.get();

  // Traverse the (value, incoming block) operand pairs of the phi and create
  // new nodes for each value.
  for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
    uint32_t value_id = phi->GetSingleWordInOperand(i);
    uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);

    Instruction* value_inst = def_use->GetDef(value_id);
    SENode* value_node = AnalyzeInstruction(value_inst);

    // If any operand is CantCompute then the whole graph is CantCompute.
    if (value_node->IsCantCompute())
      return recurrent_node_map_[phi] = CreateCantComputeNode();

    // If the value is coming from the preheader block then the value is the
    // initial value of the phi.
    if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
      phi_node->AddOffset(value_node);
    } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
      // Assumed to be in the form of step + phi.
      if (value_node->GetType() != SENode::Add)
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      SENode* step_node = nullptr;
      SENode* phi_operand = nullptr;
      SENode* operand_1 = value_node->GetChild(0);
      SENode* operand_2 = value_node->GetChild(1);

      // Find which node is the step term, i.e. the first operand that is not a
      // recurrent node.
      if (!operand_1->AsSERecurrentNode())
        step_node = operand_1;
      else if (!operand_2->AsSERecurrentNode())
        step_node = operand_2;

      // Find which node is the recurrent expression.
      if (operand_1->AsSERecurrentNode())
        phi_operand = operand_1;
      else if (operand_2->AsSERecurrentNode())
        phi_operand = operand_2;

      // If it is not in the form step + phi exit out.
      if (!(step_node && phi_operand))
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      // If the phi operand is not the same phi node exit out.
      if (phi_operand != phi_node.get())
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      // The step must not change from one iteration of the loop to the next.
      if (!IsLoopInvariant(loop, step_node))
        return recurrent_node_map_[phi] = CreateCantComputeNode();

      phi_node->AddCoefficient(step_node);
    }
  }

  // Once the node is fully built we update the map with the version from the
  // cache (if it has already been added to the cache).
  return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
}

// Returns the cached value-unknown node for |inst|, creating it on first use.
// The node is identified by the result id of |inst|.
SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
    const Instruction* inst) {
  std::unique_ptr<SENode> unknown{new SEValueUnknown(this, inst->result_id())};
  return GetCachedOrAdd(std::move(unknown));
}

// Returns the CantCompute node that was created and cached by the
// constructor; no new node is allocated here.
SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
  return cached_cant_compute_;
}

// Looks |prospective_node| up in the node cache. If an equal node is already
// cached that node is returned and |prospective_node| is discarded; otherwise
// ownership of |prospective_node| moves into the cache and it is returned.
SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
    std::unique_ptr<SENode> prospective_node) {
  const auto cached = node_cache_.find(prospective_node);
  if (cached == node_cache_.end()) {
    // Grab the raw pointer before ownership is handed to the cache.
    SENode* const fresh_node = prospective_node.get();
    node_cache_.insert(std::move(prospective_node));
    return fresh_node;
  }

  return cached->get();
}

bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
                                              const SENode* node) const {
  for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
    if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
      const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();

      // If the loop which the recurrent expression belongs to is either |loop
      // or a nested loop inside |loop| then we assume it is variant.
      if (loop->IsInsideLoop(header)) {
        return false;
      }
    } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
      // If the instruction is inside the loop we conservatively assume it is
      // loop variant.
      if (loop->IsInsideLoop(unknown->ResultId())) return false;
    }
  }

  return true;
}

// Returns the coefficient of the first recurrent expression belonging to
// |loop| found in the graph rooted at |node|, or a constant zero node when the
// graph holds no such expression.
SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
    SENode* node, const Loop* loop) {
  // GetRecurrentTerm performs the traversal of the DAG.
  SERecurrentNode* recurrence = GetRecurrentTerm(node, loop);
  if (recurrence) return recurrence->GetCoefficient();
  return CreateConstant(0);
}

// Returns a simplified copy of the add node |parent| in which every child equal
// to |old_child| is replaced by |new_child|. Nodes that are not adds are
// returned untouched.
SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
                                                 SENode* old_child,
                                                 SENode* new_child) {
  // Only handles add.
  if (parent->GetType() != SENode::Add) return parent;

  // Rebuild the add with the same children, in the same order, swapping in
  // |new_child| wherever |old_child| appears.
  std::unique_ptr<SENode> rebuilt{new SEAddNode(this)};
  for (SENode* child : *parent) {
    rebuilt->AddChild(child == old_child ? new_child : child);
  }

  return SimplifyExpression(GetCachedOrAdd(std::move(rebuilt)));
}

// Rebuilds |node| without the recurrent term belonging to |loop|, if there is
// one. A recurrent |node| is replaced by its offset when it belongs to |loop|
// and returned unchanged otherwise. For any other node, each direct child that
// is a recurrent term of |loop| is replaced by its offset and the result is
// simplified.
// NOTE(review): the rebuilt node is always an SEAddNode, whatever the type of
// |node| -- confirm that callers only pass add or recurrent nodes.
SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
    SENode* node, const Loop* loop) {
  if (SERecurrentNode* recurrent = node->AsSERecurrentNode()) {
    if (recurrent->GetLoop() != loop) return node;
    return recurrent->GetOffset();
  }

  // Copy the children over in order, substituting the offset for any child
  // that is a recurrent expression of |loop|.
  std::unique_ptr<SENode> rebuilt{new SEAddNode(this)};
  for (SENode* child : *node) {
    SERecurrentNode* child_recurrence = child->AsSERecurrentNode();
    if (child_recurrence && child_recurrence->GetLoop() == loop) {
      rebuilt->AddChild(child_recurrence->GetOffset());
    } else {
      rebuilt->AddChild(child);
    }
  }

  return SimplifyExpression(GetCachedOrAdd(std::move(rebuilt)));
}

// Walks the graph rooted at |node| and returns the first recurrent term that
// belongs to |loop|, or nullptr if the graph contains none.
SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
                                                           const Loop* loop) {
  for (auto current = node->graph_begin(); current != node->graph_end();
       ++current) {
    SERecurrentNode* candidate = current->AsSERecurrentNode();
    if (!candidate) continue;
    if (candidate->GetLoop() == loop) return candidate;
  }
  return nullptr;
}
// Returns a human readable name for the type of this node, or "NULL" if the
// type matches none of the known node types.
std::string SENode::AsString() const {
  const char* label = "NULL";
  switch (GetType()) {
    case Constant:
      label = "Constant";
      break;
    case RecurrentAddExpr:
      label = "RecurrentAddExpr";
      break;
    case Add:
      label = "Add";
      break;
    case Negative:
      label = "Negative";
      break;
    case Multiply:
      label = "Multiply";
      break;
    case ValueUnknown:
      label = "Value Unknown";
      break;
    case CanNotCompute:
      label = "Can not compute";
      break;
  }
  return label;
}

// Structural equality between two nodes: same type, same number of children,
// same children (or same coefficient, offset and loop for recurrent nodes),
// and the same result id / literal value for value-unknown / constant nodes.
// Children are compared by pointer.
bool SENode::operator==(const SENode& other) const {
  if (GetType() != other.GetType()) return false;

  const auto& other_children = other.GetChildren();
  if (other_children.size() != children_.size()) return false;

  if (const SERecurrentNode* lhs_recurrent = AsSERecurrentNode()) {
    // For SERecurrentNodes the offset and coefficient are compared explicitly:
    // the child vector is sorted by ids, so which child plays which role
    // cannot be recovered from it.
    const SERecurrentNode* rhs_recurrent = other.AsSERecurrentNode();

    // The types already match, so this cannot fail when the conversion of
    // |this| succeeded.
    assert(rhs_recurrent);

    const bool same_recurrence =
        lhs_recurrent->GetCoefficient() == rhs_recurrent->GetCoefficient() &&
        lhs_recurrent->GetOffset() == rhs_recurrent->GetOffset() &&
        lhs_recurrent->GetLoop() == rhs_recurrent->GetLoop();
    if (!same_recurrence) return false;
  } else {
    // Every other node type compares its children position by position.
    for (size_t position = 0; position < children_.size(); ++position) {
      if (children_[position] != other_children[position]) return false;
    }
  }

  // Value-unknown nodes must additionally come from the same instruction.
  if (GetType() == SENode::ValueUnknown &&
      AsSEValueUnknown()->ResultId() != other.AsSEValueUnknown()->ResultId()) {
    return false;
  }

  // Constant nodes must additionally hold the same literal value.
  if (AsSEConstantNode() &&
      AsSEConstantNode()->FoldToSingleValue() !=
          other.AsSEConstantNode()->FoldToSingleValue()) {
    return false;
  }

  return true;
}

// Inequality is defined as the negation of operator==.
bool SENode::operator!=(const SENode& other) const {
  return !operator==(other);
}

namespace {
// Helpers that append 32 or 64 bit integral values to the hash string, which is
// made of 32 bit characters. Pointers can be appended too once they have been
// reinterpreted as uintptr_t. PushToString deduces the type of its argument and
// uses sizeof on it to select the PushToStringImpl specialization for 4 or 8
// byte values; other sizes are rejected at compile time because the primary
// template is left undefined.

template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 8 byte values are appended as two characters: upper 32 bits first, then the
// lower 32 bits.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T id, std::u32string* str) const {
    // Split the value as unsigned: right-shifting a negative signed value is
    // implementation-defined before C++20.
    const uint64_t bits = static_cast<uint64_t>(id);
    str->push_back(static_cast<char32_t>(static_cast<uint32_t>(bits >> 32)));
    str->push_back(static_cast<char32_t>(static_cast<uint32_t>(bits)));
  }
};

// 4 byte values are appended as a single character.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T id, std::u32string* str) const {
    str->push_back(static_cast<char32_t>(static_cast<uint32_t>(id)));
  }
};

// Appends |id| to |str|, using one or two characters depending on sizeof(T).
template <typename T>
static void PushToString(T id, std::u32string* str) {
  PushToStringImpl<T, sizeof(T)>{}(id, str);
}

}  // namespace

// Implements the hashing of SENodes. The terms that identify a node are
// concatenated into a 32 bit character string, which is then hashed with
// std::hash.
size_t SENodeHash::operator()(const SENode* node) const {
  std::u32string key;

  // Hashing the type as a string is safer than hashing the enum as the enum is
  // very likely to collide with constants.
  const std::string type_name = node->AsString();
  for (const char letter : type_name) {
    key.push_back(static_cast<char32_t>(letter));
  }

  // The literal value only takes part in the hash of constant nodes.
  if (node->GetType() == SENode::Constant) {
    PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &key);
  }

  // Pointers are appended through their integer representation.
  const auto push_pointer = [&key](const void* pointer) {
    PushToString(reinterpret_cast<uintptr_t>(pointer), &key);
  };

  if (const SERecurrentNode* recurrent = node->AsSERecurrentNode()) {
    // Hash the loop as well so that nested inductions like i=0,i++ and j=0,j++
    // correspond to different nodes.
    push_pointer(recurrent->GetLoop());

    // Recurrent expressions can't be hashed through their child list as the
    // order of coefficient and offset matters to the hash.
    push_pointer(recurrent->GetCoefficient());
    push_pointer(recurrent->GetOffset());

    return std::hash<std::u32string>{}(key);
  }

  // Value unknown nodes hash the result id of the instruction which created
  // them.
  if (node->GetType() == SENode::ValueUnknown) {
    PushToString(node->AsSEValueUnknown()->ResultId(), &key);
  }

  // Hash the pointers of the child nodes, each SENode has a unique pointer
  // associated with it.
  for (const SENode* child : node->GetChildren()) {
    push_pointer(child);
  }

  return std::hash<std::u32string>{}(key);
}

// This overload is the actual overload used by the node_cache_ set. It hashes
// the node owned by |node| through the raw pointer overload.
size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
  const SENode* raw_node = node.get();
  return (*this)(raw_node);
}

void SENode::DumpDot(std::ostream& out, bool recurse) const {
  size_t unique_id = std::hash<const SENode*>{}(this);
  out << unique_id << " [label=\"" << AsString() << " ";
  if (GetType() == SENode::Constant) {
    out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
  }
  out << "\"]\n";
  for (const SENode* child : children_) {
    size_t child_unique_id = std::hash<const SENode*>{}(child);
    out << unique_id << " -> " << child_unique_id << " \n";
    if (recurse) child->DumpDot(out, true);
  }
}

namespace {
// Evaluates whether the value of a scalar-evolution expression is always
// greater than (or greater than or equal to) zero by computing a conservative
// sign for every node of the expression.
class IsGreaterThanZero {
 public:
  explicit IsGreaterThanZero(IRContext* context) : context_(context) {}

  // Determine if the value of |node| is always strictly greater than zero if
  // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
  // true. It returns true if the evaluation was able to conclude something, in
  // which case the result is stored in |result|. |result| is set to false when
  // nothing could be concluded.
  // The algorithm works by going through all the nodes and determining the
  // sign of each of them.
  bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
    *result = false;
    switch (Visit(node)) {
      case Signedness::kPositiveOrNegative: {
        return false;
      }
      case Signedness::kStrictlyNegative: {
        *result = false;
        break;
      }
      case Signedness::kNegative: {
        if (!or_equal_zero) {
          return false;
        }
        // NOTE(review): a value that is <= 0 may still be 0, so concluding
        // "not >= 0" here looks stronger than what is known -- confirm the
        // intended semantics with the callers.
        *result = false;
        break;
      }
      case Signedness::kStrictlyPositive: {
        *result = true;
        break;
      }
      case Signedness::kPositive: {
        // A value that is >= 0 may be 0, so nothing can be said about the
        // strict comparison.
        if (!or_equal_zero) {
          return false;
        }
        *result = true;
        break;
      }
    }
    return true;
  }

 private:
  enum class Signedness {
    kPositiveOrNegative,  // Yield a value positive or negative.
    kStrictlyNegative,    // Yield a value strictly less than 0.
    kNegative,            // Yield a value less or equal to 0.
    kStrictlyPositive,    // Yield a value strictly greater than 0.
    kPositive             // Yield a value greater or equal to 0.
  };

  // Combine the signedness according to arithmetic rules of a given operator.
  using Combiner = std::function<Signedness(Signedness, Signedness)>;

  // Returns a functor to interpret the signedness of 2 expressions as if they
  // were added. Operands of opposite (or unknown) sign give an unknown sign;
  // the sum is strict as soon as one of two same-signed operands is strict.
  Combiner GetAddCombiner() const {
    return [](Signedness lhs, Signedness rhs) {
      switch (lhs) {
        case Signedness::kPositiveOrNegative:
          break;
        case Signedness::kStrictlyNegative:
          if (rhs == Signedness::kStrictlyNegative ||
              rhs == Signedness::kNegative)
            return lhs;
          break;
        case Signedness::kNegative: {
          if (rhs == Signedness::kStrictlyNegative)
            return Signedness::kStrictlyNegative;
          if (rhs == Signedness::kNegative) return Signedness::kNegative;
          break;
        }
        case Signedness::kStrictlyPositive: {
          if (rhs == Signedness::kStrictlyPositive ||
              rhs == Signedness::kPositive) {
            return Signedness::kStrictlyPositive;
          }
          break;
        }
        case Signedness::kPositive: {
          if (rhs == Signedness::kStrictlyPositive)
            return Signedness::kStrictlyPositive;
          if (rhs == Signedness::kPositive) return Signedness::kPositive;
          break;
        }
      }
      return Signedness::kPositiveOrNegative;
    };
  }

  // Returns a functor to interpret the signedness of 2 expressions as if they
  // were multiplied. The product is only strict when both operands are strict.
  Combiner GetMulCombiner() const {
    return [](Signedness lhs, Signedness rhs) {
      switch (lhs) {
        case Signedness::kPositiveOrNegative:
          break;
        case Signedness::kStrictlyNegative: {
          switch (rhs) {
            case Signedness::kPositiveOrNegative: {
              break;
            }
            case Signedness::kStrictlyNegative: {
              return Signedness::kStrictlyPositive;
            }
            case Signedness::kNegative: {
              return Signedness::kPositive;
            }
            case Signedness::kStrictlyPositive: {
              return Signedness::kStrictlyNegative;
            }
            case Signedness::kPositive: {
              return Signedness::kNegative;
            }
          }
          break;
        }
        case Signedness::kNegative: {
          switch (rhs) {
            case Signedness::kPositiveOrNegative: {
              break;
            }
            case Signedness::kStrictlyNegative:
            case Signedness::kNegative: {
              return Signedness::kPositive;
            }
            case Signedness::kStrictlyPositive:
            case Signedness::kPositive: {
              return Signedness::kNegative;
            }
          }
          break;
        }
        case Signedness::kStrictlyPositive: {
          // Multiplying by a strictly positive value keeps the other sign.
          return rhs;
        }
        case Signedness::kPositive: {
          switch (rhs) {
            case Signedness::kPositiveOrNegative: {
              break;
            }
            case Signedness::kStrictlyNegative:
            case Signedness::kNegative: {
              return Signedness::kNegative;
            }
            case Signedness::kStrictlyPositive:
            case Signedness::kPositive: {
              return Signedness::kPositive;
            }
          }
          break;
        }
      }
      return Signedness::kPositiveOrNegative;
    };
  }

  // Dispatches on the type of |node| to the matching overload and returns the
  // signedness of the node.
  Signedness Visit(const SENode* node) {
    switch (node->GetType()) {
      case SENode::Constant:
        return Visit(node->AsSEConstantNode());
      case SENode::RecurrentAddExpr:
        return Visit(node->AsSERecurrentNode());
      case SENode::Negative:
        return Visit(node->AsSENegative());
      case SENode::CanNotCompute:
        return Visit(node->AsSECantCompute());
      case SENode::ValueUnknown:
        return Visit(node->AsSEValueUnknown());
      case SENode::Add:
        return VisitExpr(node, GetAddCombiner());
      case SENode::Multiply:
        return VisitExpr(node, GetMulCombiner());
    }
    return Signedness::kPositiveOrNegative;
  }

  // Returns the signedness of a constant |node|. Zero is reported as
  // kPositive (greater or equal to 0).
  Signedness Visit(const SEConstantNode* node) {
    if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
    if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
    if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
    return Signedness::kPositiveOrNegative;
  }

  // Returns the signedness of an unknown |node| based on its type: unsigned
  // integers are greater or equal to 0, signed integers can be anything.
  Signedness Visit(const SEValueUnknown* node) {
    Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
    analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
    assert(type && "Can't retrieve a type for the instruction");
    analysis::Integer* int_type = type->AsInteger();
    // Check the converted pointer (not |type| again) so that a non-integer
    // type is caught here rather than by the dereference below.
    assert(int_type && "Can't retrieve an integer type for the instruction");
    return int_type->IsSigned() ? Signedness::kPositiveOrNegative
                                : Signedness::kPositive;
  }

  // Returns the signedness of a recurring expression, obtained by adding the
  // (weakened) signedness of the coefficient to the signedness of the offset.
  Signedness Visit(const SERecurrentNode* node) {
    Signedness coeff_sign = Visit(node->GetCoefficient());
    // SERecurrentNode represent an affine expression in the range [0,
    // loop_bound], so the result cannot be strictly positive or negative.
    switch (coeff_sign) {
      default:
        break;
      case Signedness::kStrictlyNegative:
        coeff_sign = Signedness::kNegative;
        break;
      case Signedness::kStrictlyPositive:
        coeff_sign = Signedness::kPositive;
        break;
    }
    return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
  }

  // Returns the signedness of a negation |node|: the mirrored signedness of
  // its operand.
  Signedness Visit(const SENegative* node) {
    switch (Visit(*node->begin())) {
      case Signedness::kPositiveOrNegative: {
        return Signedness::kPositiveOrNegative;
      }
      case Signedness::kStrictlyNegative: {
        return Signedness::kStrictlyPositive;
      }
      case Signedness::kNegative: {
        return Signedness::kPositive;
      }
      case Signedness::kStrictlyPositive: {
        return Signedness::kStrictlyNegative;
      }
      case Signedness::kPositive: {
        return Signedness::kNegative;
      }
    }
    return Signedness::kPositiveOrNegative;
  }

  // Nothing is known about a value that cannot be computed.
  Signedness Visit(const SECantCompute*) {
    return Signedness::kPositiveOrNegative;
  }

  // Returns the signedness of the expression |node| by folding the signedness
  // of its operands, left to right, with the combiner |reduce|. The fold stops
  // early once the accumulated sign is unknown.
  Signedness VisitExpr(const SENode* node, const Combiner& reduce) {
    Signedness result = Visit(*node->begin());
    for (const SENode* operand : make_range(++node->begin(), node->end())) {
      if (result == Signedness::kPositiveOrNegative) {
        return Signedness::kPositiveOrNegative;
      }
      result = reduce(result, Visit(operand));
    }
    return result;
  }

  IRContext* context_;
};
}  // namespace

// Evaluates whether |node| is always strictly greater than zero. Returns true
// if the evaluation reached a conclusion, which is then stored in
// |is_gt_zero|.
bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
                                                      bool* is_gt_zero) const {
  IsGreaterThanZero evaluator(context_);
  return evaluator.Eval(node, /* or_equal_zero= */ false, is_gt_zero);
}

// Evaluates whether |node| is always greater than or equal to zero. Returns
// true if the evaluation reached a conclusion, which is then stored in
// |is_ge_zero|.
bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
    SENode* node, bool* is_ge_zero) const {
  IsGreaterThanZero evaluator(context_);
  return evaluator.Eval(node, /* or_equal_zero= */ true, is_ge_zero);
}

namespace {

// Remove one occurrence of |node| from the |mul| chain (of the form
// A * ... * |node| * ... * Z) and return the remaining product. If |node| is
// not in the chain, returns the original chain |mul|. The two direct operands
// are checked first, then the left and right sub-chains are searched
// recursively.
static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
                                              const SENode* node) {
  SENode* lhs = mul->GetChildren()[0];
  SENode* rhs = mul->GetChildren()[1];
  if (lhs == node) {
    return rhs;
  }
  if (rhs == node) {
    return lhs;
  }
  if (lhs->AsSEMultiplyNode()) {
    SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
    // |node| was removed from the left sub-chain: keep the right operand.
    if (res != lhs)
      return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
  }
  if (rhs->AsSEMultiplyNode()) {
    SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
    // |node| was removed from the right sub-chain: the result is the untouched
    // left operand times what remains of the right sub-chain.
    if (res != rhs)
      return mul->GetParentAnalysis()->CreateMultiplyNode(lhs, res);
  }

  return mul;
}
}  // namespace

// Divides this expression by |rhs_wrapper| and returns the pair
// (quotient expression, remainder). Only two cases are handled:
//  - both sides are constants: the quotient and remainder are folded;
//  - the left side is a multiply chain containing the right side: the right
//    side is removed from the chain and the remainder is 0.
// In every other case, on a division by the constant 0 and when the constant
// division would overflow, the quotient is the CantCompute node and the
// remainder is 0.
std::pair<SExpression, int64_t> SExpression::operator/(
    SExpression rhs_wrapper) const {
  SENode* lhs = node_;
  SENode* rhs = rhs_wrapper.node_;
  // Check for division by 0.
  if (rhs->AsSEConstantNode() &&
      !rhs->AsSEConstantNode()->FoldToSingleValue()) {
    return {scev_->CreateCantComputeNode(), 0};
  }

  // Trivial case.
  if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
    int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
    int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
    // The quotient of INT64_MIN / -1 is not representable in int64_t, which
    // makes both / and % undefined behaviour; give up instead.
    if (lhs_value == std::numeric_limits<int64_t>::min() && rhs_value == -1) {
      return {scev_->CreateCantComputeNode(), 0};
    }
    return {scev_->CreateConstant(lhs_value / rhs_value),
            lhs_value % rhs_value};
  }

  // look for a "c U / U" pattern.
  if (lhs->AsSEMultiplyNode()) {
    assert(lhs->GetChildren().size() == 2 &&
           "More than 2 operand for a multiply node.");
    SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
    if (res != lhs) {
      return {res, 0};
    }
  }

  return {scev_->CreateCantComputeNode(), 0};
}

}  // namespace opt
}  // namespace spvtools