// Copyright (c) 2018 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/val/validate.h" #include "source/opcode.h" #include "source/val/instruction.h" #include "source/val/validation_state.h" namespace spvtools { namespace val { namespace { spv_result_t ValidateConstantBool(ValidationState_t& _, const Instruction* inst) { auto type = _.FindDef(inst->type_id()); if (!type || type->opcode() != SpvOpTypeBool) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '" << _.getIdName(inst->type_id()) << "' is not a boolean type."; } return SPV_SUCCESS; } spv_result_t ValidateConstantComposite(ValidationState_t& _, const Instruction* inst) { std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode()); const auto result_type = _.FindDef(inst->type_id()); if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Result Type <id> '" << _.getIdName(inst->type_id()) << "' is not a composite type."; } const auto constituent_count = inst->words().size() - 3; switch (result_type->opcode()) { case SpvOpTypeVector: { const auto component_count = result_type->GetOperandAs<uint32_t>(2); if (component_count != constituent_count) { // TODO: Output ID's on diagnostic return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> count does not match " "Result Type <id> '" << _.getIdName(result_type->id()) << "'s vector component count."; } const auto component_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1)); if (!component_type) { return _.diag(SPV_ERROR_INVALID_ID, result_type) << "Component type is not defined."; } for (size_t constituent_index = 2; constituent_index < inst->operands().size(); constituent_index++) { const auto constituent_id = inst->GetOperandAs<uint32_t>(constituent_index); const auto constituent = _.FindDef(constituent_id); if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' is not a constant or undef."; } const auto constituent_result_type = _.FindDef(constituent->type_id()); if (!constituent_result_type || component_type->opcode() != constituent_result_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "'s type does not match Result Type <id> '" << _.getIdName(result_type->id()) << "'s vector element type."; } } } break; case SpvOpTypeMatrix: { const auto column_count = result_type->GetOperandAs<uint32_t>(2); if (column_count != constituent_count) { // TODO: Output ID's on diagnostic return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> count does not match " "Result Type <id> '" << _.getIdName(result_type->id()) << "'s matrix column count."; } const auto column_type = _.FindDef(result_type->words()[2]); if (!column_type) { return _.diag(SPV_ERROR_INVALID_ID, result_type) << "Column type is not defined."; } const auto component_count = column_type->GetOperandAs<uint32_t>(2); const auto component_type = _.FindDef(column_type->GetOperandAs<uint32_t>(1)); if (!component_type) { return _.diag(SPV_ERROR_INVALID_ID, column_type) << "Component type is not defined."; } for (size_t constituent_index = 2; constituent_index < inst->operands().size(); constituent_index++) { const auto constituent_id = inst->GetOperandAs<uint32_t>(constituent_index); const auto constituent = _.FindDef(constituent_id); if (!constituent || !(SpvOpConstantComposite == constituent->opcode() || SpvOpSpecConstantComposite == constituent->opcode() || SpvOpUndef == constituent->opcode())) { // The message says "... or undef" because the spec does not say // undef is a constant. return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' is not a constant composite or undef."; } const auto vector = _.FindDef(constituent->type_id()); if (!vector) { return _.diag(SPV_ERROR_INVALID_ID, constituent) << "Result type is not defined."; } if (column_type->opcode() != vector->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' type does not match Result Type <id> '" << _.getIdName(result_type->id()) << "'s matrix column type."; } const auto vector_component_type = _.FindDef(vector->GetOperandAs<uint32_t>(1)); if (component_type->id() != vector_component_type->id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' component type does not match Result Type <id> '" << _.getIdName(result_type->id()) << "'s matrix column component type."; } if (component_count != vector->words()[3]) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' vector component count does not match Result Type <id> '" << _.getIdName(result_type->id()) << "'s vector component count."; } } } break; case SpvOpTypeArray: { auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1)); if (!element_type) { return _.diag(SPV_ERROR_INVALID_ID, result_type) << "Element type is not defined."; } const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2)); if (!length) { return _.diag(SPV_ERROR_INVALID_ID, result_type) << "Length is not defined."; } bool is_int32; bool is_const; uint32_t value; std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id()); if (is_int32 && is_const && value != constituent_count) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent count does not match " "Result Type <id> '" << _.getIdName(result_type->id()) << "'s array length."; } for (size_t constituent_index = 2; constituent_index < inst->operands().size(); constituent_index++) { const auto constituent_id = inst->GetOperandAs<uint32_t>(constituent_index); const auto constituent = _.FindDef(constituent_id); if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' is not a constant or undef."; } const auto constituent_type = _.FindDef(constituent->type_id()); if (!constituent_type) { return _.diag(SPV_ERROR_INVALID_ID, constituent) << "Result type is not defined."; } if (element_type->id() != constituent_type->id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "'s type does not match Result Type <id> '" << _.getIdName(result_type->id()) << "'s array element type."; } } } break; case SpvOpTypeStruct: { const auto member_count = result_type->words().size() - 2; if (member_count != constituent_count) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(inst->type_id()) << "' count does not match Result Type <id> '" << _.getIdName(result_type->id()) << "'s struct member count."; } for (uint32_t constituent_index = 2, member_index = 1; constituent_index < inst->operands().size(); constituent_index++, member_index++) { const auto constituent_id = inst->GetOperandAs<uint32_t>(constituent_index); const auto constituent = _.FindDef(constituent_id); if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' is not a constant or undef."; } const auto constituent_type = _.FindDef(constituent->type_id()); if (!constituent_type) { return _.diag(SPV_ERROR_INVALID_ID, constituent) << "Result type is not defined."; } const auto member_type_id = result_type->GetOperandAs<uint32_t>(member_index); const auto member_type = _.FindDef(member_type_id); if (!member_type || member_type->id() != constituent_type->id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) << opcode_name << " Constituent <id> '" << _.getIdName(constituent_id) << "' type does not match the Result Type <id> '" << _.getIdName(result_type->id()) << "'s member type."; } } } break; default: break; } return SPV_SUCCESS; } spv_result_t ValidateConstantSampler(ValidationState_t& _, const Instruction* inst) { const auto result_type = _.FindDef(inst->type_id()); if (!result_type || result_type->opcode() != SpvOpTypeSampler) { return _.diag(SPV_ERROR_INVALID_ID, result_type) << "OpConstantSampler Result Type <id> '" << _.getIdName(inst->type_id()) << "' is not a sampler type."; } return SPV_SUCCESS; } // True if instruction defines a type that can have a null value, as defined by // the SPIR-V spec. Tracks composite-type components through module to check // nullability transitively. bool IsTypeNullable(const std::vector<uint32_t>& instruction, const ValidationState_t& _) { uint16_t opcode; uint16_t word_count; spvOpcodeSplit(instruction[0], &word_count, &opcode); switch (static_cast<SpvOp>(opcode)) { case SpvOpTypeBool: case SpvOpTypeInt: case SpvOpTypeFloat: case SpvOpTypePointer: case SpvOpTypeEvent: case SpvOpTypeDeviceEvent: case SpvOpTypeReserveId: case SpvOpTypeQueue: return true; case SpvOpTypeArray: case SpvOpTypeMatrix: case SpvOpTypeVector: { auto base_type = _.FindDef(instruction[2]); return base_type && IsTypeNullable(base_type->words(), _); } case SpvOpTypeStruct: { for (size_t elementIndex = 2; elementIndex < instruction.size(); ++elementIndex) { auto element = _.FindDef(instruction[elementIndex]); if (!element || !IsTypeNullable(element->words(), _)) return false; } return true; } default: return false; } } spv_result_t ValidateConstantNull(ValidationState_t& _, const Instruction* inst) { const auto result_type = _.FindDef(inst->type_id()); if (!result_type || !IsTypeNullable(result_type->words(), _)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpConstantNull Result Type <id> '" << _.getIdName(inst->type_id()) << "' cannot have a null value."; } return SPV_SUCCESS; } spv_result_t ValidateSpecConstantOp(ValidationState_t& _, const Instruction* inst) { const auto op = inst->GetOperandAs<SpvOp>(2); // The binary parser already ensures that the op is valid for *some* // environment. Here we check restrictions. switch(op) { case SpvOpQuantizeToF16: if (!_.HasCapability(SpvCapabilityShader)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Specialization constant operation " << spvOpcodeString(op) << " requires Shader capability"; } break; case SpvOpUConvert: if (!_.features().uconvert_spec_constant_op && !_.HasCapability(SpvCapabilityKernel)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "UConvert requires Kernel capability or extension " "SPV_AMD_gpu_shader_int16"; } break; case SpvOpConvertFToS: case SpvOpConvertSToF: case SpvOpConvertFToU: case SpvOpConvertUToF: case SpvOpConvertPtrToU: case SpvOpConvertUToPtr: case SpvOpGenericCastToPtr: case SpvOpPtrCastToGeneric: case SpvOpBitcast: case SpvOpFNegate: case SpvOpFAdd: case SpvOpFSub: case SpvOpFMul: case SpvOpFDiv: case SpvOpFRem: case SpvOpFMod: case SpvOpAccessChain: case SpvOpInBoundsAccessChain: case SpvOpPtrAccessChain: case SpvOpInBoundsPtrAccessChain: if (!_.HasCapability(SpvCapabilityKernel)) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Specialization constant operation " << spvOpcodeString(op) << " requires Kernel capability"; } break; default: break; } // TODO(dneto): Validate result type and arguments to the various operations. return SPV_SUCCESS; } } // namespace spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) { switch (inst->opcode()) { case SpvOpConstantTrue: case SpvOpConstantFalse: case SpvOpSpecConstantTrue: case SpvOpSpecConstantFalse: if (auto error = ValidateConstantBool(_, inst)) return error; break; case SpvOpConstantComposite: case SpvOpSpecConstantComposite: if (auto error = ValidateConstantComposite(_, inst)) return error; break; case SpvOpConstantSampler: if (auto error = ValidateConstantSampler(_, inst)) return error; break; case SpvOpConstantNull: if (auto error = ValidateConstantNull(_, inst)) return error; break; case SpvOpSpecConstantOp: if (auto error = ValidateSpecConstantOp(_, inst)) return error; break; default: break; } return SPV_SUCCESS; } } // namespace val } // namespace spvtools