// Copyright (c) 2018 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/opt/scalar_analysis.h" #include <algorithm> #include <functional> #include <string> #include <utility> #include "source/opt/ir_context.h" // Transforms a given scalar operation instruction into a DAG representation. // // 1. Take an instruction and traverse its operands until we reach a // constant node or an instruction which we do not know how to compute the // value, such as a load. // // 2. Create a new node for each instruction traversed and build the nodes for // the in operands of that instruction as well. // // 3. Add the operand nodes as children of the first and hash the node. Use the // hash to see if the node is already in the cache. We ensure the children are // always in sorted order so that two nodes with the same children but inserted // in a different order have the same hash and so that the overloaded operator== // will return true. If the node is already in the cache return the cached // version instead. // // 4. The created DAG can then be simplified by // ScalarAnalysis::SimplifyExpression, implemented in // scalar_analysis_simplification.cpp. See that file for further information on // the simplification process. // namespace spvtools { namespace opt { uint32_t SENode::NumberOfNodes = 0; ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context) : context_(context), pretend_equal_{} { // Create and cached the CantComputeNode. cached_cant_compute_ = GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this))); } SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) { // If operand is can't compute then the whole graph is can't compute. if (operand->IsCantCompute()) return CreateCantComputeNode(); if (operand->GetType() == SENode::Constant) { return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue()); } std::unique_ptr<SENode> negation_node{new SENegative(this)}; negation_node->AddChild(operand); return GetCachedOrAdd(std::move(negation_node)); } SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) { return GetCachedOrAdd( std::unique_ptr<SENode>(new SEConstantNode(this, integer))); } SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression( const Loop* loop, SENode* offset, SENode* coefficient) { assert(loop && "Recurrent add expressions must have a valid loop."); // If operands are can't compute then the whole graph is can't compute. if (offset->IsCantCompute() || coefficient->IsCantCompute()) return CreateCantComputeNode(); const Loop* loop_to_use = nullptr; if (pretend_equal_[loop]) { loop_to_use = pretend_equal_[loop]; } else { loop_to_use = loop; } std::unique_ptr<SERecurrentNode> phi_node{ new SERecurrentNode(this, loop_to_use)}; phi_node->AddOffset(offset); phi_node->AddCoefficient(coefficient); return GetCachedOrAdd(std::move(phi_node)); } SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp( const Instruction* multiply) { assert(multiply->opcode() == SpvOp::SpvOpIMul && "Multiply node did not come from a multiply instruction"); analysis::DefUseManager* def_use = context_->get_def_use_mgr(); SENode* op1 = AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0))); SENode* op2 = AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1))); return CreateMultiplyNode(op1, op2); } SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1, SENode* operand_2) { // If operands are can't compute then the whole graph is can't compute. if (operand_1->IsCantCompute() || operand_2->IsCantCompute()) return CreateCantComputeNode(); if (operand_1->GetType() == SENode::Constant && operand_2->GetType() == SENode::Constant) { return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() * operand_2->AsSEConstantNode()->FoldToSingleValue()); } std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)}; multiply_node->AddChild(operand_1); multiply_node->AddChild(operand_2); return GetCachedOrAdd(std::move(multiply_node)); } SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1, SENode* operand_2) { // Fold if both operands are constant. if (operand_1->GetType() == SENode::Constant && operand_2->GetType() == SENode::Constant) { return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() - operand_2->AsSEConstantNode()->FoldToSingleValue()); } return CreateAddNode(operand_1, CreateNegation(operand_2)); } SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1, SENode* operand_2) { // Fold if both operands are constant and the |simplify| flag is true. if (operand_1->GetType() == SENode::Constant && operand_2->GetType() == SENode::Constant) { return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() + operand_2->AsSEConstantNode()->FoldToSingleValue()); } // If operands are can't compute then the whole graph is can't compute. if (operand_1->IsCantCompute() || operand_2->IsCantCompute()) return CreateCantComputeNode(); std::unique_ptr<SENode> add_node{new SEAddNode(this)}; add_node->AddChild(operand_1); add_node->AddChild(operand_2); return GetCachedOrAdd(std::move(add_node)); } SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) { auto itr = recurrent_node_map_.find(inst); if (itr != recurrent_node_map_.end()) return itr->second; SENode* output = nullptr; switch (inst->opcode()) { case SpvOp::SpvOpPhi: { output = AnalyzePhiInstruction(inst); break; } case SpvOp::SpvOpConstant: case SpvOp::SpvOpConstantNull: { output = AnalyzeConstant(inst); break; } case SpvOp::SpvOpISub: case SpvOp::SpvOpIAdd: { output = AnalyzeAddOp(inst); break; } case SpvOp::SpvOpIMul: { output = AnalyzeMultiplyOp(inst); break; } default: { output = CreateValueUnknownNode(inst); break; } } return output; } SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) { if (inst->opcode() == SpvOp::SpvOpConstantNull) return CreateConstant(0); assert(inst->opcode() == SpvOp::SpvOpConstant); assert(inst->NumInOperands() == 1); int64_t value = 0; // Look up the instruction in the constant manager. const analysis::Constant* constant = context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id()); if (!constant) return CreateCantComputeNode(); const analysis::IntConstant* int_constant = constant->AsIntConstant(); // Exit out if it is a 64 bit integer. if (!int_constant || int_constant->words().size() != 1) return CreateCantComputeNode(); if (int_constant->type()->AsInteger()->IsSigned()) { value = int_constant->GetS32BitValue(); } else { value = int_constant->GetU32BitValue(); } return CreateConstant(value); } // Handles both addition and subtraction. If the |sub| flag is set then the // addition will be op1+(-op2) otherwise op1+op2. SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) { assert((inst->opcode() == SpvOp::SpvOpIAdd || inst->opcode() == SpvOp::SpvOpISub) && "Add node must be created from a OpIAdd or OpISub instruction"); analysis::DefUseManager* def_use = context_->get_def_use_mgr(); SENode* op1 = AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0))); SENode* op2 = AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1))); // To handle subtraction we wrap the second operand in a unary negation node. if (inst->opcode() == SpvOp::SpvOpISub) { op2 = CreateNegation(op2); } return CreateAddNode(op1, op2); } SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) { // The phi should only have two incoming value pairs. if (phi->NumInOperands() != 4) { return CreateCantComputeNode(); } analysis::DefUseManager* def_use = context_->get_def_use_mgr(); // Get the basic block this instruction belongs to. BasicBlock* basic_block = context_->get_instr_block(const_cast<Instruction*>(phi)); // And then the function that the basic blocks belongs to. Function* function = basic_block->GetParent(); // Use the function to get the loop descriptor. LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function); // We only handle phis in loops at the moment. if (!loop_descriptor) return CreateCantComputeNode(); // Get the innermost loop which this block belongs to. Loop* loop = (*loop_descriptor)[basic_block->id()]; // If the loop doesn't exist or doesn't have a preheader or latch block, exit // out. if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() || loop->GetHeaderBlock() != basic_block) return recurrent_node_map_[phi] = CreateCantComputeNode(); const Loop* loop_to_use = nullptr; if (pretend_equal_[loop]) { loop_to_use = pretend_equal_[loop]; } else { loop_to_use = loop; } std::unique_ptr<SERecurrentNode> phi_node{ new SERecurrentNode(this, loop_to_use)}; // We add the node to this map to allow it to be returned before the node is // fully built. This is needed as the subsequent call to AnalyzeInstruction // could lead back to this |phi| instruction so we return the pointer // immediately in AnalyzeInstruction to break the recursion. recurrent_node_map_[phi] = phi_node.get(); // Traverse the operands of the instruction an create new nodes for each one. for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) { uint32_t value_id = phi->GetSingleWordInOperand(i); uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1); Instruction* value_inst = def_use->GetDef(value_id); SENode* value_node = AnalyzeInstruction(value_inst); // If any operand is CantCompute then the whole graph is CantCompute. if (value_node->IsCantCompute()) return recurrent_node_map_[phi] = CreateCantComputeNode(); // If the value is coming from the preheader block then the value is the // initial value of the phi. if (incoming_label_id == loop->GetPreHeaderBlock()->id()) { phi_node->AddOffset(value_node); } else if (incoming_label_id == loop->GetLatchBlock()->id()) { // Assumed to be in the form of step + phi. if (value_node->GetType() != SENode::Add) return recurrent_node_map_[phi] = CreateCantComputeNode(); SENode* step_node = nullptr; SENode* phi_operand = nullptr; SENode* operand_1 = value_node->GetChild(0); SENode* operand_2 = value_node->GetChild(1); // Find which node is the step term. if (!operand_1->AsSERecurrentNode()) step_node = operand_1; else if (!operand_2->AsSERecurrentNode()) step_node = operand_2; // Find which node is the recurrent expression. if (operand_1->AsSERecurrentNode()) phi_operand = operand_1; else if (operand_2->AsSERecurrentNode()) phi_operand = operand_2; // If it is not in the form step + phi exit out. if (!(step_node && phi_operand)) return recurrent_node_map_[phi] = CreateCantComputeNode(); // If the phi operand is not the same phi node exit out. if (phi_operand != phi_node.get()) return recurrent_node_map_[phi] = CreateCantComputeNode(); if (!IsLoopInvariant(loop, step_node)) return recurrent_node_map_[phi] = CreateCantComputeNode(); phi_node->AddCoefficient(step_node); } } // Once the node is fully built we update the map with the version from the // cache (if it has already been added to the cache). return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node)); } SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode( const Instruction* inst) { std::unique_ptr<SEValueUnknown> load_node{ new SEValueUnknown(this, inst->result_id())}; return GetCachedOrAdd(std::move(load_node)); } SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() { return cached_cant_compute_; } // Add the created node into the cache of nodes. If it already exists return it. SENode* ScalarEvolutionAnalysis::GetCachedOrAdd( std::unique_ptr<SENode> prospective_node) { auto itr = node_cache_.find(prospective_node); if (itr != node_cache_.end()) { return (*itr).get(); } SENode* raw_ptr_to_node = prospective_node.get(); node_cache_.insert(std::move(prospective_node)); return raw_ptr_to_node; } bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop, const SENode* node) const { for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) { if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) { const BasicBlock* header = rec->GetLoop()->GetHeaderBlock(); // If the loop which the recurrent expression belongs to is either |loop // or a nested loop inside |loop| then we assume it is variant. if (loop->IsInsideLoop(header)) { return false; } } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) { // If the instruction is inside the loop we conservatively assume it is // loop variant. if (loop->IsInsideLoop(unknown->ResultId())) return false; } } return true; } SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm( SENode* node, const Loop* loop) { // Traverse the DAG to find the recurrent expression belonging to |loop|. for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { SERecurrentNode* rec = itr->AsSERecurrentNode(); if (rec && rec->GetLoop() == loop) { return rec->GetCoefficient(); } } return CreateConstant(0); } SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent, SENode* old_child, SENode* new_child) { // Only handles add. if (parent->GetType() != SENode::Add) return parent; std::vector<SENode*> new_children; for (SENode* child : *parent) { if (child == old_child) { new_children.push_back(new_child); } else { new_children.push_back(child); } } std::unique_ptr<SENode> add_node{new SEAddNode(this)}; for (SENode* child : new_children) { add_node->AddChild(child); } return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); } // Rebuild the |node| eliminating, if it exists, the recurrent term which // belongs to the |loop|. SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm( SENode* node, const Loop* loop) { // If the node is already a recurrent expression belonging to loop then just // return the offset. SERecurrentNode* recurrent = node->AsSERecurrentNode(); if (recurrent) { if (recurrent->GetLoop() == loop) { return recurrent->GetOffset(); } else { return node; } } std::vector<SENode*> new_children; // Otherwise find the recurrent node in the children of this node. for (auto itr : *node) { recurrent = itr->AsSERecurrentNode(); if (recurrent && recurrent->GetLoop() == loop) { new_children.push_back(recurrent->GetOffset()); } else { new_children.push_back(itr); } } std::unique_ptr<SENode> add_node{new SEAddNode(this)}; for (SENode* child : new_children) { add_node->AddChild(child); } return SimplifyExpression(GetCachedOrAdd(std::move(add_node))); } // Return the recurrent term belonging to |loop| if it appears in the graph // starting at |node| or null if it doesn't. SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node, const Loop* loop) { for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) { SERecurrentNode* rec = itr->AsSERecurrentNode(); if (rec && rec->GetLoop() == loop) { return rec; } } return nullptr; } std::string SENode::AsString() const { switch (GetType()) { case Constant: return "Constant"; case RecurrentAddExpr: return "RecurrentAddExpr"; case Add: return "Add"; case Negative: return "Negative"; case Multiply: return "Multiply"; case ValueUnknown: return "Value Unknown"; case CanNotCompute: return "Can not compute"; } return "NULL"; } bool SENode::operator==(const SENode& other) const { if (GetType() != other.GetType()) return false; if (other.GetChildren().size() != children_.size()) return false; const SERecurrentNode* this_as_recurrent = AsSERecurrentNode(); // Check the children are the same, for SERecurrentNodes we need to check the // offset and coefficient manually as the child vector is sorted by ids so the // offset/coefficient information is lost. if (!this_as_recurrent) { for (size_t index = 0; index < children_.size(); ++index) { if (other.GetChildren()[index] != children_[index]) return false; } } else { const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode(); // We've already checked the types are the same, this should not fail if // this->AsSERecurrentNode() succeeded. assert(other_as_recurrent); if (this_as_recurrent->GetCoefficient() != other_as_recurrent->GetCoefficient()) return false; if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset()) return false; if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop()) return false; } // If we're dealing with a value unknown node check both nodes were created by // the same instruction. if (GetType() == SENode::ValueUnknown) { if (AsSEValueUnknown()->ResultId() != other.AsSEValueUnknown()->ResultId()) { return false; } } if (AsSEConstantNode()) { if (AsSEConstantNode()->FoldToSingleValue() != other.AsSEConstantNode()->FoldToSingleValue()) return false; } return true; } bool SENode::operator!=(const SENode& other) const { return !(*this == other); } namespace { // Helper functions to insert 32/64 bit values into the 32 bit hash string. This // allows us to add pointers to the string by reinterpreting the pointers as // uintptr_t. PushToString will deduce the type, call sizeof on it and use // that size to call into the correct PushToStringImpl functor depending on // whether it is 32 or 64 bit. template <typename T, size_t size_of_t> struct PushToStringImpl; template <typename T> struct PushToStringImpl<T, 8> { void operator()(T id, std::u32string* str) { str->push_back(static_cast<uint32_t>(id >> 32)); str->push_back(static_cast<uint32_t>(id)); } }; template <typename T> struct PushToStringImpl<T, 4> { void operator()(T id, std::u32string* str) { str->push_back(static_cast<uint32_t>(id)); } }; template <typename T> static void PushToString(T id, std::u32string* str) { PushToStringImpl<T, sizeof(T)>{}(id, str); } } // namespace // Implements the hashing of SENodes. size_t SENodeHash::operator()(const SENode* node) const { // Concatinate the terms into a string which we can hash. std::u32string hash_string{}; // Hashing the type as a string is safer than hashing the enum as the enum is // very likely to collide with constants. for (char ch : node->AsString()) { hash_string.push_back(static_cast<char32_t>(ch)); } // We just ignore the literal value unless it is a constant. if (node->GetType() == SENode::Constant) PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string); const SERecurrentNode* recurrent = node->AsSERecurrentNode(); // If we're dealing with a recurrent expression hash the loop as well so that // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes. if (recurrent) { PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()), &hash_string); // Recurrent expressions can't be hashed using the normal method as the // order of coefficient and offset matters to the hash. PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()), &hash_string); PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()), &hash_string); return std::hash<std::u32string>{}(hash_string); } // Hash the result id of the original instruction which created this node if // it is a value unknown node. if (node->GetType() == SENode::ValueUnknown) { PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string); } // Hash the pointers of the child nodes, each SENode has a unique pointer // associated with it. const std::vector<SENode*>& children = node->GetChildren(); for (const SENode* child : children) { PushToString(reinterpret_cast<uintptr_t>(child), &hash_string); } return std::hash<std::u32string>{}(hash_string); } // This overload is the actual overload used by the node_cache_ set. size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const { return this->operator()(node.get()); } void SENode::DumpDot(std::ostream& out, bool recurse) const { size_t unique_id = std::hash<const SENode*>{}(this); out << unique_id << " [label=\"" << AsString() << " "; if (GetType() == SENode::Constant) { out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue(); } out << "\"]\n"; for (const SENode* child : children_) { size_t child_unique_id = std::hash<const SENode*>{}(child); out << unique_id << " -> " << child_unique_id << " \n"; if (recurse) child->DumpDot(out, true); } } namespace { class IsGreaterThanZero { public: explicit IsGreaterThanZero(IRContext* context) : context_(context) {} // Determine if the value of |node| is always strictly greater than zero if // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is // true. It returns true is the evaluation was able to conclude something, in // which case the result is stored in |result|. // The algorithm work by going through all the nodes and determine the // sign of each of them. bool Eval(const SENode* node, bool or_equal_zero, bool* result) { *result = false; switch (Visit(node)) { case Signedness::kPositiveOrNegative: { return false; } case Signedness::kStrictlyNegative: { *result = false; break; } case Signedness::kNegative: { if (!or_equal_zero) { return false; } *result = false; break; } case Signedness::kStrictlyPositive: { *result = true; break; } case Signedness::kPositive: { if (!or_equal_zero) { return false; } *result = true; break; } } return true; } private: enum class Signedness { kPositiveOrNegative, // Yield a value positive or negative. kStrictlyNegative, // Yield a value strictly less than 0. kNegative, // Yield a value less or equal to 0. kStrictlyPositive, // Yield a value strictly greater than 0. kPositive // Yield a value greater or equal to 0. }; // Combine the signedness according to arithmetic rules of a given operator. using Combiner = std::function<Signedness(Signedness, Signedness)>; // Returns a functor to interpret the signedness of 2 expressions as if they // were added. Combiner GetAddCombiner() const { return [](Signedness lhs, Signedness rhs) { switch (lhs) { case Signedness::kPositiveOrNegative: break; case Signedness::kStrictlyNegative: if (rhs == Signedness::kStrictlyNegative || rhs == Signedness::kNegative) return lhs; break; case Signedness::kNegative: { if (rhs == Signedness::kStrictlyNegative) return Signedness::kStrictlyNegative; if (rhs == Signedness::kNegative) return Signedness::kNegative; break; } case Signedness::kStrictlyPositive: { if (rhs == Signedness::kStrictlyPositive || rhs == Signedness::kPositive) { return Signedness::kStrictlyPositive; } break; } case Signedness::kPositive: { if (rhs == Signedness::kStrictlyPositive) return Signedness::kStrictlyPositive; if (rhs == Signedness::kPositive) return Signedness::kPositive; break; } } return Signedness::kPositiveOrNegative; }; } // Returns a functor to interpret the signedness of 2 expressions as if they // were multiplied. Combiner GetMulCombiner() const { return [](Signedness lhs, Signedness rhs) { switch (lhs) { case Signedness::kPositiveOrNegative: break; case Signedness::kStrictlyNegative: { switch (rhs) { case Signedness::kPositiveOrNegative: { break; } case Signedness::kStrictlyNegative: { return Signedness::kStrictlyPositive; } case Signedness::kNegative: { return Signedness::kPositive; } case Signedness::kStrictlyPositive: { return Signedness::kStrictlyNegative; } case Signedness::kPositive: { return Signedness::kNegative; } } break; } case Signedness::kNegative: { switch (rhs) { case Signedness::kPositiveOrNegative: { break; } case Signedness::kStrictlyNegative: case Signedness::kNegative: { return Signedness::kPositive; } case Signedness::kStrictlyPositive: case Signedness::kPositive: { return Signedness::kNegative; } } break; } case Signedness::kStrictlyPositive: { return rhs; } case Signedness::kPositive: { switch (rhs) { case Signedness::kPositiveOrNegative: { break; } case Signedness::kStrictlyNegative: case Signedness::kNegative: { return Signedness::kNegative; } case Signedness::kStrictlyPositive: case Signedness::kPositive: { return Signedness::kPositive; } } break; } } return Signedness::kPositiveOrNegative; }; } Signedness Visit(const SENode* node) { switch (node->GetType()) { case SENode::Constant: return Visit(node->AsSEConstantNode()); break; case SENode::RecurrentAddExpr: return Visit(node->AsSERecurrentNode()); break; case SENode::Negative: return Visit(node->AsSENegative()); break; case SENode::CanNotCompute: return Visit(node->AsSECantCompute()); break; case SENode::ValueUnknown: return Visit(node->AsSEValueUnknown()); break; case SENode::Add: return VisitExpr(node, GetAddCombiner()); break; case SENode::Multiply: return VisitExpr(node, GetMulCombiner()); break; } return Signedness::kPositiveOrNegative; } // Returns the signedness of a constant |node|. Signedness Visit(const SEConstantNode* node) { if (0 == node->FoldToSingleValue()) return Signedness::kPositive; if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive; if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative; return Signedness::kPositiveOrNegative; } // Returns the signedness of an unknown |node| based on its type. Signedness Visit(const SEValueUnknown* node) { Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId()); analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id()); assert(type && "Can't retrieve a type for the instruction"); analysis::Integer* int_type = type->AsInteger(); assert(type && "Can't retrieve an integer type for the instruction"); return int_type->IsSigned() ? Signedness::kPositiveOrNegative : Signedness::kPositive; } // Returns the signedness of a recurring expression. Signedness Visit(const SERecurrentNode* node) { Signedness coeff_sign = Visit(node->GetCoefficient()); // SERecurrentNode represent an affine expression in the range [0, // loop_bound], so the result cannot be strictly positive or negative. switch (coeff_sign) { default: break; case Signedness::kStrictlyNegative: coeff_sign = Signedness::kNegative; break; case Signedness::kStrictlyPositive: coeff_sign = Signedness::kPositive; break; } return GetAddCombiner()(coeff_sign, Visit(node->GetOffset())); } // Returns the signedness of a negation |node|. Signedness Visit(const SENegative* node) { switch (Visit(*node->begin())) { case Signedness::kPositiveOrNegative: { return Signedness::kPositiveOrNegative; } case Signedness::kStrictlyNegative: { return Signedness::kStrictlyPositive; } case Signedness::kNegative: { return Signedness::kPositive; } case Signedness::kStrictlyPositive: { return Signedness::kStrictlyNegative; } case Signedness::kPositive: { return Signedness::kNegative; } } return Signedness::kPositiveOrNegative; } Signedness Visit(const SECantCompute*) { return Signedness::kPositiveOrNegative; } // Returns the signedness of a binary expression by using the combiner // |reduce|. Signedness VisitExpr( const SENode* node, std::function<Signedness(Signedness, Signedness)> reduce) { Signedness result = Visit(*node->begin()); for (const SENode* operand : make_range(++node->begin(), node->end())) { if (result == Signedness::kPositiveOrNegative) { return Signedness::kPositiveOrNegative; } result = reduce(result, Visit(operand)); } return result; } IRContext* context_; }; } // namespace bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node, bool* is_gt_zero) const { return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero); } bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero( SENode* node, bool* is_ge_zero) const { return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero); } namespace { // Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z), // if |node| is not in the chain, returns the original chain. static SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul, const SENode* node) { SENode* lhs = mul->GetChildren()[0]; SENode* rhs = mul->GetChildren()[1]; if (lhs == node) { return rhs; } if (rhs == node) { return lhs; } if (lhs->AsSEMultiplyNode()) { SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node); if (res != lhs) return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); } if (rhs->AsSEMultiplyNode()) { SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node); if (res != rhs) return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs); } return mul; } } // namespace std::pair<SExpression, int64_t> SExpression::operator/( SExpression rhs_wrapper) const { SENode* lhs = node_; SENode* rhs = rhs_wrapper.node_; // Check for division by 0. if (rhs->AsSEConstantNode() && !rhs->AsSEConstantNode()->FoldToSingleValue()) { return {scev_->CreateCantComputeNode(), 0}; } // Trivial case. if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) { int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue(); int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue(); return {scev_->CreateConstant(lhs_value / rhs_value), lhs_value % rhs_value}; } // look for a "c U / U" pattern. if (lhs->AsSEMultiplyNode()) { assert(lhs->GetChildren().size() == 2 && "More than 2 operand for a multiply node."); SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs); if (res != lhs) { return {res, 0}; } } return {scev_->CreateCantComputeNode(), 0}; } } // namespace opt } // namespace spvtools