C++程序  |  707行  |  22.14 KB

// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/fold.h"

#include <cassert>
#include <cstdint>
#include <vector>

#include "source/opt/const_folding_rules.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/folding_rules.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"

namespace spvtools {
namespace opt {
namespace {

#ifndef INT32_MIN
#define INT32_MIN (-2147483648)
#endif

#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif

#ifndef UINT32_MAX
#define UINT32_MAX 0xffffffff /* 4294967295U */
#endif

}  // namespace

uint32_t InstructionFolder::UnaryOperate(SpvOp opcode, uint32_t operand) const {
  switch (opcode) {
    // Arthimetics
    case SpvOp::SpvOpSNegate: {
      int32_t s_operand = static_cast<int32_t>(operand);
      if (s_operand == std::numeric_limits<int32_t>::min()) {
        return s_operand;
      }
      return -s_operand;
    }
    case SpvOp::SpvOpNot:
      return ~operand;
    case SpvOp::SpvOpLogicalNot:
      return !static_cast<bool>(operand);
    default:
      assert(false &&
             "Unsupported unary operation for OpSpecConstantOp instruction");
      return 0u;
  }
}

uint32_t InstructionFolder::BinaryOperate(SpvOp opcode, uint32_t a,
                                          uint32_t b) const {
  switch (opcode) {
    // Arthimetics
    case SpvOp::SpvOpIAdd:
      return a + b;
    case SpvOp::SpvOpISub:
      return a - b;
    case SpvOp::SpvOpIMul:
      return a * b;
    case SpvOp::SpvOpUDiv:
      if (b != 0) {
        return a / b;
      } else {
        // Dividing by 0 is undefined, so we will just pick 0.
        return 0;
      }
    case SpvOp::SpvOpSDiv:
      if (b != 0u) {
        return (static_cast<int32_t>(a)) / (static_cast<int32_t>(b));
      } else {
        // Dividing by 0 is undefined, so we will just pick 0.
        return 0;
      }
    case SpvOp::SpvOpSRem: {
      // The sign of non-zero result comes from the first operand: a. This is
      // guaranteed by C++11 rules for integer division operator. The division
      // result is rounded toward zero, so the result of '%' has the sign of
      // the first operand.
      if (b != 0u) {
        return static_cast<int32_t>(a) % static_cast<int32_t>(b);
      } else {
        // Remainder when dividing with 0 is undefined, so we will just pick 0.
        return 0;
      }
    }
    case SpvOp::SpvOpSMod: {
      // The sign of non-zero result comes from the second operand: b
      if (b != 0u) {
        int32_t rem = BinaryOperate(SpvOp::SpvOpSRem, a, b);
        int32_t b_prim = static_cast<int32_t>(b);
        return (rem + b_prim) % b_prim;
      } else {
        // Mod with 0 is undefined, so we will just pick 0.
        return 0;
      }
    }
    case SpvOp::SpvOpUMod:
      if (b != 0u) {
        return (a % b);
      } else {
        // Mod with 0 is undefined, so we will just pick 0.
        return 0;
      }

    // Shifting
    case SpvOp::SpvOpShiftRightLogical:
      if (b >= 32) {
        // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
        // When |b| == 32, doing the shift in C++ in undefined, but the result
        // will be 0, so just return that value.
        return 0;
      }
      return a >> b;
    case SpvOp::SpvOpShiftRightArithmetic:
      if (b > 32) {
        // This is undefined behaviour.  Choose 0 for consistency.
        return 0;
      }
      if (b == 32) {
        // Doing the shift in C++ is undefined, but the result is defined in the
        // spir-v spec.  Find that value another way.
        if (static_cast<int32_t>(a) >= 0) {
          return 0;
        } else {
          return static_cast<uint32_t>(-1);
        }
      }
      return (static_cast<int32_t>(a)) >> b;
    case SpvOp::SpvOpShiftLeftLogical:
      if (b >= 32) {
        // This is undefined behaviour when |b| > 32.  Choose 0 for consistency.
        // When |b| == 32, doing the shift in C++ in undefined, but the result
        // will be 0, so just return that value.
        return 0;
      }
      return a << b;

    // Bitwise operations
    case SpvOp::SpvOpBitwiseOr:
      return a | b;
    case SpvOp::SpvOpBitwiseAnd:
      return a & b;
    case SpvOp::SpvOpBitwiseXor:
      return a ^ b;

    // Logical
    case SpvOp::SpvOpLogicalEqual:
      return (static_cast<bool>(a)) == (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalNotEqual:
      return (static_cast<bool>(a)) != (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalOr:
      return (static_cast<bool>(a)) || (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalAnd:
      return (static_cast<bool>(a)) && (static_cast<bool>(b));

    // Comparison
    case SpvOp::SpvOpIEqual:
      return a == b;
    case SpvOp::SpvOpINotEqual:
      return a != b;
    case SpvOp::SpvOpULessThan:
      return a < b;
    case SpvOp::SpvOpSLessThan:
      return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
    case SpvOp::SpvOpUGreaterThan:
      return a > b;
    case SpvOp::SpvOpSGreaterThan:
      return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
    case SpvOp::SpvOpULessThanEqual:
      return a <= b;
    case SpvOp::SpvOpSLessThanEqual:
      return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
    case SpvOp::SpvOpUGreaterThanEqual:
      return a >= b;
    case SpvOp::SpvOpSGreaterThanEqual:
      return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
    default:
      assert(false &&
             "Unsupported binary operation for OpSpecConstantOp instruction");
      return 0u;
  }
}

uint32_t InstructionFolder::TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b,
                                           uint32_t c) const {
  switch (opcode) {
    case SpvOp::SpvOpSelect:
      return (static_cast<bool>(a)) ? b : c;
    default:
      assert(false &&
             "Unsupported ternary operation for OpSpecConstantOp instruction");
      return 0u;
  }
}

uint32_t InstructionFolder::OperateWords(
    SpvOp opcode, const std::vector<uint32_t>& operand_words) const {
  switch (operand_words.size()) {
    case 1:
      return UnaryOperate(opcode, operand_words.front());
    case 2:
      return BinaryOperate(opcode, operand_words.front(), operand_words.back());
    case 3:
      return TernaryOperate(opcode, operand_words[0], operand_words[1],
                            operand_words[2]);
    default:
      assert(false && "Invalid number of operands");
      return 0;
  }
}

bool InstructionFolder::FoldInstructionInternal(Instruction* inst) const {
  auto identity_map = [](uint32_t id) { return id; };
  Instruction* folded_inst = FoldInstructionToConstant(inst, identity_map);
  if (folded_inst != nullptr) {
    inst->SetOpcode(SpvOpCopyObject);
    inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {folded_inst->result_id()}}});
    return true;
  }

  SpvOp opcode = inst->opcode();
  analysis::ConstantManager* const_manager = context_->get_constant_mgr();

  std::vector<const analysis::Constant*> constants =
      const_manager->GetOperandConstants(inst);

  for (const FoldingRule& rule : GetFoldingRules().GetRulesForOpcode(opcode)) {
    if (rule(context_, inst, constants)) {
      return true;
    }
  }
  return false;
}

// Returns the result of performing an operation on scalar constant operands.
// This function extracts the operand values as 32 bit words and returns the
// result in 32 bit word. Scalar constants with longer than 32-bit width are
// not accepted in this function.
uint32_t InstructionFolder::FoldScalars(
    SpvOp opcode,
    const std::vector<const analysis::Constant*>& operands) const {
  assert(IsFoldableOpcode(opcode) &&
         "Unhandled instruction opcode in FoldScalars");
  std::vector<uint32_t> operand_values_in_raw_words;
  for (const auto& operand : operands) {
    if (const analysis::ScalarConstant* scalar = operand->AsScalarConstant()) {
      const auto& scalar_words = scalar->words();
      assert(scalar_words.size() == 1 &&
             "Scalar constants with longer than 32-bit width are not allowed "
             "in FoldScalars()");
      operand_values_in_raw_words.push_back(scalar_words.front());
    } else if (operand->AsNullConstant()) {
      operand_values_in_raw_words.push_back(0u);
    } else {
      assert(false &&
             "FoldScalars() only accepts ScalarConst or NullConst type of "
             "constant");
    }
  }
  return OperateWords(opcode, operand_values_in_raw_words);
}

bool InstructionFolder::FoldBinaryIntegerOpToConstant(
    Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
    uint32_t* result) const {
  SpvOp opcode = inst->opcode();
  analysis::ConstantManager* const_manger = context_->get_constant_mgr();

  uint32_t ids[2];
  const analysis::IntConstant* constants[2];
  for (uint32_t i = 0; i < 2; i++) {
    const Operand* operand = &inst->GetInOperand(i);
    if (operand->type != SPV_OPERAND_TYPE_ID) {
      return false;
    }
    ids[i] = id_map(operand->words[0]);
    const analysis::Constant* constant =
        const_manger->FindDeclaredConstant(ids[i]);
    constants[i] = (constant != nullptr ? constant->AsIntConstant() : nullptr);
  }

  switch (opcode) {
    // Arthimetics
    case SpvOp::SpvOpIMul:
      for (uint32_t i = 0; i < 2; i++) {
        if (constants[i] != nullptr && constants[i]->IsZero()) {
          *result = 0;
          return true;
        }
      }
      break;
    case SpvOp::SpvOpUDiv:
    case SpvOp::SpvOpSDiv:
    case SpvOp::SpvOpSRem:
    case SpvOp::SpvOpSMod:
    case SpvOp::SpvOpUMod:
      // This changes undefined behaviour (ie divide by 0) into a 0.
      for (uint32_t i = 0; i < 2; i++) {
        if (constants[i] != nullptr && constants[i]->IsZero()) {
          *result = 0;
          return true;
        }
      }
      break;

    // Shifting
    case SpvOp::SpvOpShiftRightLogical:
    case SpvOp::SpvOpShiftLeftLogical:
      if (constants[1] != nullptr) {
        // When shifting by a value larger than the size of the result, the
        // result is undefined.  We are setting the undefined behaviour to a
        // result of 0.  If the shift amount is the same as the size of the
        // result, then the result is defined, and it 0.
        uint32_t shift_amount = constants[1]->GetU32BitValue();
        if (shift_amount >= 32) {
          *result = 0;
          return true;
        }
      }
      break;

    // Bitwise operations
    case SpvOp::SpvOpBitwiseOr:
      for (uint32_t i = 0; i < 2; i++) {
        if (constants[i] != nullptr) {
          // TODO: Change the mask against a value based on the bit width of the
          // instruction result type.  This way we can handle say 16-bit values
          // as well.
          uint32_t mask = constants[i]->GetU32BitValue();
          if (mask == 0xFFFFFFFF) {
            *result = 0xFFFFFFFF;
            return true;
          }
        }
      }
      break;
    case SpvOp::SpvOpBitwiseAnd:
      for (uint32_t i = 0; i < 2; i++) {
        if (constants[i] != nullptr) {
          if (constants[i]->IsZero()) {
            *result = 0;
            return true;
          }
        }
      }
      break;

    // Comparison
    case SpvOp::SpvOpULessThan:
      if (constants[0] != nullptr &&
          constants[0]->GetU32BitValue() == UINT32_MAX) {
        *result = false;
        return true;
      }
      if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
        *result = false;
        return true;
      }
      break;
    case SpvOp::SpvOpSLessThan:
      if (constants[0] != nullptr &&
          constants[0]->GetS32BitValue() == INT32_MAX) {
        *result = false;
        return true;
      }
      if (constants[1] != nullptr &&
          constants[1]->GetS32BitValue() == INT32_MIN) {
        *result = false;
        return true;
      }
      break;
    case SpvOp::SpvOpUGreaterThan:
      if (constants[0] != nullptr && constants[0]->IsZero()) {
        *result = false;
        return true;
      }
      if (constants[1] != nullptr &&
          constants[1]->GetU32BitValue() == UINT32_MAX) {
        *result = false;
        return true;
      }
      break;
    case SpvOp::SpvOpSGreaterThan:
      if (constants[0] != nullptr &&
          constants[0]->GetS32BitValue() == INT32_MIN) {
        *result = false;
        return true;
      }
      if (constants[1] != nullptr &&
          constants[1]->GetS32BitValue() == INT32_MAX) {
        *result = false;
        return true;
      }
      break;
    case SpvOp::SpvOpULessThanEqual:
      if (constants[0] != nullptr && constants[0]->IsZero()) {
        *result = true;
        return true;
      }
      if (constants[1] != nullptr &&
          constants[1]->GetU32BitValue() == UINT32_MAX) {
        *result = true;
        return true;
      }
      break;
    case SpvOp::SpvOpSLessThanEqual:
      if (constants[0] != nullptr &&
          constants[0]->GetS32BitValue() == INT32_MIN) {
        *result = true;
        return true;
      }
      if (constants[1] != nullptr &&
          constants[1]->GetS32BitValue() == INT32_MAX) {
        *result = true;
        return true;
      }
      break;
    case SpvOp::SpvOpUGreaterThanEqual:
      if (constants[0] != nullptr &&
          constants[0]->GetU32BitValue() == UINT32_MAX) {
        *result = true;
        return true;
      }
      if (constants[1] != nullptr && constants[1]->GetU32BitValue() == 0) {
        *result = true;
        return true;
      }
      break;
    case SpvOp::SpvOpSGreaterThanEqual:
      if (constants[0] != nullptr &&
          constants[0]->GetS32BitValue() == INT32_MAX) {
        *result = true;
        return true;
      }
      if (constants[1] != nullptr &&
          constants[1]->GetS32BitValue() == INT32_MIN) {
        *result = true;
        return true;
      }
      break;
    default:
      break;
  }
  return false;
}

bool InstructionFolder::FoldBinaryBooleanOpToConstant(
    Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
    uint32_t* result) const {
  SpvOp opcode = inst->opcode();
  analysis::ConstantManager* const_manger = context_->get_constant_mgr();

  uint32_t ids[2];
  const analysis::BoolConstant* constants[2];
  for (uint32_t i = 0; i < 2; i++) {
    const Operand* operand = &inst->GetInOperand(i);
    if (operand->type != SPV_OPERAND_TYPE_ID) {
      return false;
    }
    ids[i] = id_map(operand->words[0]);
    const analysis::Constant* constant =
        const_manger->FindDeclaredConstant(ids[i]);
    constants[i] = (constant != nullptr ? constant->AsBoolConstant() : nullptr);
  }

  switch (opcode) {
    // Logical
    case SpvOp::SpvOpLogicalOr:
      for (uint32_t i = 0; i < 2; i++) {
        if (constants[i] != nullptr) {
          if (constants[i]->value()) {
            *result = true;
            return true;
          }
        }
      }
      break;
    case SpvOp::SpvOpLogicalAnd:
      for (uint32_t i = 0; i < 2; i++) {
        if (constants[i] != nullptr) {
          if (!constants[i]->value()) {
            *result = false;
            return true;
          }
        }
      }
      break;

    default:
      break;
  }
  return false;
}

bool InstructionFolder::FoldIntegerOpToConstant(
    Instruction* inst, const std::function<uint32_t(uint32_t)>& id_map,
    uint32_t* result) const {
  assert(IsFoldableOpcode(inst->opcode()) &&
         "Unhandled instruction opcode in FoldScalars");
  switch (inst->NumInOperands()) {
    case 2:
      return FoldBinaryIntegerOpToConstant(inst, id_map, result) ||
             FoldBinaryBooleanOpToConstant(inst, id_map, result);
    default:
      return false;
  }
}

std::vector<uint32_t> InstructionFolder::FoldVectors(
    SpvOp opcode, uint32_t num_dims,
    const std::vector<const analysis::Constant*>& operands) const {
  assert(IsFoldableOpcode(opcode) &&
         "Unhandled instruction opcode in FoldVectors");
  std::vector<uint32_t> result;
  for (uint32_t d = 0; d < num_dims; d++) {
    std::vector<uint32_t> operand_values_for_one_dimension;
    for (const auto& operand : operands) {
      if (const analysis::VectorConstant* vector_operand =
              operand->AsVectorConstant()) {
        // Extract the raw value of the scalar component constants
        // in 32-bit words here. The reason of not using FoldScalars() here
        // is that we do not create temporary null constants as components
        // when the vector operand is a NullConstant because Constant creation
        // may need extra checks for the validity and that is not manageed in
        // here.
        if (const analysis::ScalarConstant* scalar_component =
                vector_operand->GetComponents().at(d)->AsScalarConstant()) {
          const auto& scalar_words = scalar_component->words();
          assert(
              scalar_words.size() == 1 &&
              "Vector components with longer than 32-bit width are not allowed "
              "in FoldVectors()");
          operand_values_for_one_dimension.push_back(scalar_words.front());
        } else if (operand->AsNullConstant()) {
          operand_values_for_one_dimension.push_back(0u);
        } else {
          assert(false &&
                 "VectorConst should only has ScalarConst or NullConst as "
                 "components");
        }
      } else if (operand->AsNullConstant()) {
        operand_values_for_one_dimension.push_back(0u);
      } else {
        assert(false &&
               "FoldVectors() only accepts VectorConst or NullConst type of "
               "constant");
      }
    }
    result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
  }
  return result;
}

bool InstructionFolder::IsFoldableOpcode(SpvOp opcode) const {
  // NOTE: Extend to more opcodes as new cases are handled in the folder
  // functions.
  switch (opcode) {
    case SpvOp::SpvOpBitwiseAnd:
    case SpvOp::SpvOpBitwiseOr:
    case SpvOp::SpvOpBitwiseXor:
    case SpvOp::SpvOpIAdd:
    case SpvOp::SpvOpIEqual:
    case SpvOp::SpvOpIMul:
    case SpvOp::SpvOpINotEqual:
    case SpvOp::SpvOpISub:
    case SpvOp::SpvOpLogicalAnd:
    case SpvOp::SpvOpLogicalEqual:
    case SpvOp::SpvOpLogicalNot:
    case SpvOp::SpvOpLogicalNotEqual:
    case SpvOp::SpvOpLogicalOr:
    case SpvOp::SpvOpNot:
    case SpvOp::SpvOpSDiv:
    case SpvOp::SpvOpSelect:
    case SpvOp::SpvOpSGreaterThan:
    case SpvOp::SpvOpSGreaterThanEqual:
    case SpvOp::SpvOpShiftLeftLogical:
    case SpvOp::SpvOpShiftRightArithmetic:
    case SpvOp::SpvOpShiftRightLogical:
    case SpvOp::SpvOpSLessThan:
    case SpvOp::SpvOpSLessThanEqual:
    case SpvOp::SpvOpSMod:
    case SpvOp::SpvOpSNegate:
    case SpvOp::SpvOpSRem:
    case SpvOp::SpvOpUDiv:
    case SpvOp::SpvOpUGreaterThan:
    case SpvOp::SpvOpUGreaterThanEqual:
    case SpvOp::SpvOpULessThan:
    case SpvOp::SpvOpULessThanEqual:
    case SpvOp::SpvOpUMod:
      return true;
    default:
      return false;
  }
}

bool InstructionFolder::IsFoldableConstant(
    const analysis::Constant* cst) const {
  // Currently supported constants are 32-bit values or null constants.
  if (const analysis::ScalarConstant* scalar = cst->AsScalarConstant())
    return scalar->words().size() == 1;
  else
    return cst->AsNullConstant() != nullptr;
}

Instruction* InstructionFolder::FoldInstructionToConstant(
    Instruction* inst, std::function<uint32_t(uint32_t)> id_map) const {
  analysis::ConstantManager* const_mgr = context_->get_constant_mgr();

  if (!inst->IsFoldableByFoldScalar() &&
      !GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
    return nullptr;
  }
  // Collect the values of the constant parameters.
  std::vector<const analysis::Constant*> constants;
  bool missing_constants = false;
  inst->ForEachInId([&constants, &missing_constants, const_mgr,
                     &id_map](uint32_t* op_id) {
    uint32_t id = id_map(*op_id);
    const analysis::Constant* const_op = const_mgr->FindDeclaredConstant(id);
    if (!const_op) {
      constants.push_back(nullptr);
      missing_constants = true;
    } else {
      constants.push_back(const_op);
    }
  });

  if (GetConstantFoldingRules().HasFoldingRule(inst->opcode())) {
    const analysis::Constant* folded_const = nullptr;
    for (auto rule :
         GetConstantFoldingRules().GetRulesForOpcode(inst->opcode())) {
      folded_const = rule(context_, inst, constants);
      if (folded_const != nullptr) {
        Instruction* const_inst =
            const_mgr->GetDefiningInstruction(folded_const, inst->type_id());
        assert(const_inst->type_id() == inst->type_id());
        // May be a new instruction that needs to be analysed.
        context_->UpdateDefUse(const_inst);
        return const_inst;
      }
    }
  }

  uint32_t result_val = 0;
  bool successful = false;
  // If all parameters are constant, fold the instruction to a constant.
  if (!missing_constants && inst->IsFoldableByFoldScalar()) {
    result_val = FoldScalars(inst->opcode(), constants);
    successful = true;
  }

  if (!successful && inst->IsFoldableByFoldScalar()) {
    successful = FoldIntegerOpToConstant(inst, id_map, &result_val);
  }

  if (successful) {
    const analysis::Constant* result_const =
        const_mgr->GetConstant(const_mgr->GetType(inst), {result_val});
    Instruction* folded_inst =
        const_mgr->GetDefiningInstruction(result_const, inst->type_id());
    return folded_inst;
  }
  return nullptr;
}

bool InstructionFolder::IsFoldableType(Instruction* type_inst) const {
  // Support 32-bit integers.
  if (type_inst->opcode() == SpvOpTypeInt) {
    return type_inst->GetSingleWordInOperand(0) == 32;
  }
  // Support booleans.
  if (type_inst->opcode() == SpvOpTypeBool) {
    return true;
  }
  // Nothing else yet.
  return false;
}

bool InstructionFolder::FoldInstruction(Instruction* inst) const {
  bool modified = false;
  Instruction* folded_inst(inst);
  while (folded_inst->opcode() != SpvOpCopyObject &&
         FoldInstructionInternal(&*folded_inst)) {
    modified = true;
  }
  return modified;
}

}  // namespace opt
}  // namespace spvtools