// Copyright (c) 2018 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/comp/markv_encoder.h" #include "source/binary.h" #include "source/opcode.h" #include "spirv-tools/libspirv.hpp" namespace spvtools { namespace comp { namespace { const size_t kCommentNumWhitespaces = 2; } // namespace spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) { auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); if (codec) { uint64_t bits = 0; size_t num_bits = 0; if (codec->Encode(word, &bits, &num_bits)) { // Encoding successful. writer_.WriteBits(bits, num_bits); return SPV_SUCCESS; } else { // Encoding failed, write kMarkvNoneOfTheAbove flag. if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) return Diag(SPV_ERROR_INTERNAL) << "Non-id word Huffman table for " << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " << operand_index_ << " is missing kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); } } // Fallback encoding. const size_t chunk_length = model_->GetOperandVariableWidthChunkLength(operand_.type); if (chunk_length) { writer_.WriteVariableWidthU32(word, chunk_length); } else { writer_.WriteUnencoded(word); } return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode, uint32_t num_operands) { uint64_t bits = 0; size_t num_bits = 0; const uint32_t word = opcode | (num_operands << 16); // First try to use the Markov chain codec. auto* codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); if (codec) { if (codec->Encode(word, &bits, &num_bits)) { // The word was successfully encoded into bits/num_bits. writer_.WriteBits(bits, num_bits); return SPV_SUCCESS; } else { // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove // and use fallback encoding. if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) return Diag(SPV_ERROR_INTERNAL) << "opcode_and_num_operands Huffman table for " << spvOpcodeString(GetPrevOpcode()) << "is missing kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); } } // Fallback to base-rate codec. codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); assert(codec); if (codec->Encode(word, &bits, &num_bits)) { // The word was successfully encoded into bits/num_bits. writer_.WriteBits(bits, num_bits); return SPV_SUCCESS; } else { // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove // and return false. if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) return Diag(SPV_ERROR_INTERNAL) << "Global opcode_and_num_operands Huffman table is missing " << "kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); return SPV_UNSUPPORTED; } } spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf, uint64_t fallback_method) { const auto* codec = GetMtfHuffmanCodec(mtf); if (!codec) { assert(fallback_method != kMtfNone); codec = GetMtfHuffmanCodec(fallback_method); } if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank"; uint64_t bits = 0; size_t num_bits = 0; if (rank < MarkvCodec::kMtfSmallestRankEncodedByValue) { // Encode using Huffman coding. if (!codec->Encode(rank, &bits, &num_bits)) return Diag(SPV_ERROR_INTERNAL) << "Failed to encode MTF rank with Huffman"; writer_.WriteBits(bits, num_bits); } else { // Encode by value. if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &bits, &num_bits)) return Diag(SPV_ERROR_INTERNAL) << "Failed to encode kMtfRankEncodedByValueSignal"; writer_.WriteBits(bits, num_bits); writer_.WriteVariableWidthU32( rank - MarkvCodec::kMtfSmallestRankEncodedByValue, model_->mtf_rank_chunk_length()); } return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) { // Get the descriptor for id. const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id); auto* codec = model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); uint64_t bits = 0; size_t num_bits = 0; uint64_t mtf = kMtfNone; if (long_descriptor && codec && codec->Encode(long_descriptor, &bits, &num_bits)) { // If the descriptor exists and is in the table, write the descriptor and // proceed to encoding the rank. writer_.WriteBits(bits, num_bits); mtf = GetMtfLongIdDescriptor(long_descriptor); } else { if (codec) { // The descriptor doesn't exist or we have no coding for it. Write // kMarkvNoneOfTheAbove and go to fallback method. if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) return Diag(SPV_ERROR_INTERNAL) << "Descriptor Huffman table for " << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index " << operand_index_ << " is missing kMarkvNoneOfTheAbove"; writer_.WriteBits(bits, num_bits); } if (model_->id_fallback_strategy() != MarkvModel::IdFallbackStrategy::kShortDescriptor) { return SPV_UNSUPPORTED; } const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id); writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits); if (short_descriptor == 0) { // Forward declared id. return SPV_UNSUPPORTED; } mtf = GetMtfShortIdDescriptor(short_descriptor); } // Descriptor has been encoded. Now encode the rank of the id in the // associated mtf sequence. return EncodeExistingId(mtf, id); } spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) { assert(multi_mtf_.GetSize(mtf) > 0); if (multi_mtf_.GetSize(mtf) == 1) { // If the sequence has only one element no need to write rank, the decoder // would make the same decision. return SPV_SUCCESS; } uint32_t rank = 0; if (!multi_mtf_.RankFromValue(mtf, id, &rank)) return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence"; return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank); } spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) { { // Try to encode using id descriptor mtfs. const spv_result_t result = EncodeIdWithDescriptor(id); if (result != SPV_UNSUPPORTED) return result; // If can't be done continue with other methods. } const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( SpvOp(inst_.opcode))(operand_index_); uint32_t rank = 0; if (model_->id_fallback_strategy() == MarkvModel::IdFallbackStrategy::kRuleBased) { // Encode using rule-based mtf. uint64_t mtf = GetRuleBasedMtf(); if (mtf != kMtfNone && !can_forward_declare) { assert(multi_mtf_.HasValue(kMtfAll, id)); return EncodeExistingId(mtf, id); } if (mtf == kMtfNone) mtf = kMtfAll; if (!multi_mtf_.RankFromValue(mtf, id, &rank)) { // This is the first occurrence of a forward declared id. multi_mtf_.Insert(kMtfAll, id); multi_mtf_.Insert(kMtfForwardDeclared, id); if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id); rank = 0; } return EncodeMtfRankHuffman(rank, mtf, kMtfAll); } else { assert(can_forward_declare); if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) { // This is the first occurrence of a forward declared id. multi_mtf_.Insert(kMtfForwardDeclared, id); rank = 0; } writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); return SPV_SUCCESS; } } spv_result_t MarkvEncoder::EncodeTypeId() { if (inst_.opcode == SpvOpFunctionParameter) { assert(!remaining_function_parameter_types_.empty()); assert(inst_.type_id == remaining_function_parameter_types_.front()); remaining_function_parameter_types_.pop_front(); return SPV_SUCCESS; } { // Try to encode using id descriptor mtfs. const spv_result_t result = EncodeIdWithDescriptor(inst_.type_id); if (result != SPV_UNSUPPORTED) return result; // If can't be done continue with other methods. } assert(model_->id_fallback_strategy() == MarkvModel::IdFallbackStrategy::kRuleBased); uint64_t mtf = GetRuleBasedMtf(); assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( operand_index_)); if (mtf == kMtfNone) { mtf = kMtfTypeNonFunction; // Function types should have been handled by GetRuleBasedMtf. assert(inst_.opcode != SpvOpFunction); } return EncodeExistingId(mtf, inst_.type_id); } spv_result_t MarkvEncoder::EncodeResultId() { uint32_t rank = 0; const uint64_t num_still_forward_declared = multi_mtf_.GetSize(kMtfForwardDeclared); if (num_still_forward_declared) { // We write the rank only if kMtfForwardDeclared is not empty. If it is // empty the decoder knows that there are no forward declared ids to expect. if (multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank)) { // This is a definition of a forward declared id. We can remove the id // from kMtfForwardDeclared. if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) return Diag(SPV_ERROR_INTERNAL) << "Failed to remove id from kMtfForwardDeclared"; writer_.WriteBits(1, 1); writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length()); } else { rank = 0; writer_.WriteBits(0, 1); } } if (model_->id_fallback_strategy() == MarkvModel::IdFallbackStrategy::kRuleBased) { if (!rank) { multi_mtf_.Insert(kMtfAll, inst_.result_id); } } return SPV_SUCCESS; } spv_result_t MarkvEncoder::EncodeLiteralNumber( const spv_parsed_operand_t& operand) { if (operand.number_bit_width <= 32) { const uint32_t word = inst_.words[operand.offset]; return EncodeNonIdWord(word); } else { assert(operand.number_bit_width <= 64); const uint64_t word = uint64_t(inst_.words[operand.offset]) | (uint64_t(inst_.words[operand.offset + 1]) << 32); if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { int64_t val = 0; std::memcpy(&val, &word, 8); writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), model_->s64_block_exponent()); } else if (operand.number_kind == SPV_NUMBER_FLOATING) { writer_.WriteUnencoded(word); } else { return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; } } return SPV_SUCCESS; } void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) { const size_t num_bits_to_next_byte = GetNumBitsToNextByte(writer_.GetNumBits()); if (num_bits_to_next_byte == 0 || num_bits_to_next_byte > byte_break_if_less_than) return; if (logger_) { logger_->AppendWhitespaces(kCommentNumWhitespaces); logger_->AppendText("<byte break>"); } writer_.WriteBits(0, num_bits_to_next_byte); } spv_result_t MarkvEncoder::EncodeInstruction( const spv_parsed_instruction_t& inst) { SpvOp opcode = SpvOp(inst.opcode); inst_ = inst; LogDisassemblyInstruction(); const spv_result_t opcode_encodig_result = EncodeOpcodeAndNumOperands(opcode, inst.num_operands); if (opcode_encodig_result < 0) return opcode_encodig_result; if (opcode_encodig_result != SPV_SUCCESS) { // Fallback encoding for opcode and num_operands. writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length()); if (!OpcodeHasFixedNumberOfOperands(opcode)) { // If the opcode has a variable number of operands, encode the number of // operands with the instruction. if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces); writer_.WriteVariableWidthU16(inst.num_operands, model_->num_operands_chunk_length()); } } // Write operands. const uint32_t num_operands = inst_.num_operands; for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) { operand_ = inst_.operands[operand_index_]; if (logger_) { logger_->AppendWhitespaces(kCommentNumWhitespaces); logger_->AppendText("<"); logger_->AppendText(spvOperandTypeStr(operand_.type)); logger_->AppendText(">"); } switch (operand_.type) { case SPV_OPERAND_TYPE_RESULT_ID: case SPV_OPERAND_TYPE_TYPE_ID: case SPV_OPERAND_TYPE_ID: case SPV_OPERAND_TYPE_OPTIONAL_ID: case SPV_OPERAND_TYPE_SCOPE_ID: case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { const uint32_t id = inst_.words[operand_.offset]; if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) { const spv_result_t result = EncodeTypeId(); if (result != SPV_SUCCESS) return result; } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) { const spv_result_t result = EncodeResultId(); if (result != SPV_SUCCESS) return result; } else { const spv_result_t result = EncodeRefId(id); if (result != SPV_SUCCESS) return result; } PromoteIfNeeded(id); break; } case SPV_OPERAND_TYPE_LITERAL_INTEGER: { const spv_result_t result = EncodeNonIdWord(inst_.words[operand_.offset]); if (result != SPV_SUCCESS) return result; break; } case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: { const spv_result_t result = EncodeLiteralNumber(operand_); if (result != SPV_SUCCESS) return result; break; } case SPV_OPERAND_TYPE_LITERAL_STRING: { const char* src = reinterpret_cast<const char*>(&inst_.words[operand_.offset]); auto* codec = model_->GetLiteralStringHuffmanCodec(opcode); if (codec) { uint64_t bits = 0; size_t num_bits = 0; const std::string str = src; if (codec->Encode(str, &bits, &num_bits)) { writer_.WriteBits(bits, num_bits); break; } else { bool result = codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits); (void)result; assert(result); writer_.WriteBits(bits, num_bits); } } const size_t length = spv_strnlen_s(src, operand_.num_words * 4); if (length == operand_.num_words * 4) return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to find terminal character of literal string"; for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]); break; } default: { for (int i = 0; i < operand_.num_words; ++i) { const uint32_t word = inst_.words[operand_.offset + i]; const spv_result_t result = EncodeNonIdWord(word); if (result != SPV_SUCCESS) return result; } break; } } } AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte); if (logger_) { logger_->NewLine(); logger_->NewLine(); if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; } ProcessCurInstruction(); return SPV_SUCCESS; } } // namespace comp } // namespace spvtools