// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv_encoder.h"
#include "source/binary.h"
#include "source/opcode.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
namespace comp {
namespace {

// Number of whitespace characters the logger is asked to append before each
// inline annotation (operand type tags, the byte-break marker) it adds to the
// current line.
constexpr size_t kCommentNumWhitespaces = 2;

}  // namespace
// Writes a single non-id operand word.
//
// If the model provides a Huffman codec for the current (opcode, operand
// index) pair it is tried first. When the codec does not know |word|, the
// kMarkvNoneOfTheAbove escape symbol is written instead and the word is then
// emitted with the fallback encoding. Without a codec only the fallback
// encoding is written.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INTERNAL if the codec has no entry for the
// escape symbol.
spv_result_t MarkvEncoder::EncodeNonIdWord(uint32_t word) {
  if (auto* huffman =
          model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_)) {
    uint64_t code = 0;
    size_t code_length = 0;

    if (huffman->Encode(word, &code, &code_length)) {
      // The word is in the table: its code is all that has to be written.
      writer_.WriteBits(code, code_length);
      return SPV_SUCCESS;
    }

    // The word is not in the table: emit the escape symbol and fall through
    // to the fallback encoding below.
    if (!huffman->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &code,
                         &code_length)) {
      return Diag(SPV_ERROR_INTERNAL)
             << "Non-id word Huffman table for "
             << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
             << operand_index_ << " is missing kMarkvNoneOfTheAbove";
    }
    writer_.WriteBits(code, code_length);
  }

  // Fallback: variable-width encoding when the model supplies a chunk length
  // for this operand type, the raw word otherwise.
  const size_t chunk_length =
      model_->GetOperandVariableWidthChunkLength(operand_.type);
  if (chunk_length == 0) {
    writer_.WriteUnencoded(word);
  } else {
    writer_.WriteVariableWidthU32(word, chunk_length);
  }
  return SPV_SUCCESS;
}
// Encodes the instruction header: |opcode| and |num_operands| are packed into
// one 32-bit symbol (opcode in the low 16 bits, operand count shifted left by
// 16) which is then Huffman coded.
//
// The codec selected by the previous opcode is tried first. If it is absent,
// or does not contain the symbol (in which case its kMarkvNoneOfTheAbove
// escape is written), the base-rate codec registered under SpvOpNop is used.
//
// Returns:
//   SPV_SUCCESS         the symbol was written by one of the codecs.
//   SPV_UNSUPPORTED     no codec contains the symbol; the escape code(s) have
//                       been written and the caller has to emit the opcode
//                       with a fallback encoding.
//   SPV_ERROR_INTERNAL  a codec lacks the kMarkvNoneOfTheAbove symbol, or the
//                       base-rate codec does not exist.
spv_result_t MarkvEncoder::EncodeOpcodeAndNumOperands(uint32_t opcode,
                                                      uint32_t num_operands) {
  uint64_t bits = 0;
  size_t num_bits = 0;

  const uint32_t word = opcode | (num_operands << 16);

  // First try to use the Markov chain codec.
  auto* codec =
      model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  if (codec) {
    if (codec->Encode(word, &bits, &num_bits)) {
      // The word was successfully encoded into bits/num_bits.
      writer_.WriteBits(bits, num_bits);
      return SPV_SUCCESS;
    }

    // The word is not in the Huffman table. Write kMarkvNoneOfTheAbove
    // and use fallback encoding.
    if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits,
                       &num_bits)) {
      // Note the leading space: the opcode name is streamed right before it.
      return Diag(SPV_ERROR_INTERNAL)
             << "opcode_and_num_operands Huffman table for "
             << spvOpcodeString(GetPrevOpcode())
             << " is missing kMarkvNoneOfTheAbove";
    }
    writer_.WriteBits(bits, num_bits);
  }

  // Fallback to base-rate codec.
  codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  assert(codec);
  if (!codec) {
    // The assert above compiles away under NDEBUG; report the broken model
    // instead of dereferencing a null codec.
    return Diag(SPV_ERROR_INTERNAL)
           << "Global opcode_and_num_operands Huffman table is missing";
  }

  if (codec->Encode(word, &bits, &num_bits)) {
    // The word was successfully encoded into bits/num_bits.
    writer_.WriteBits(bits, num_bits);
    return SPV_SUCCESS;
  }

  // The word is not in the base-rate table either. Write
  // kMarkvNoneOfTheAbove and report that the symbol is unsupported.
  if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &bits, &num_bits)) {
    return Diag(SPV_ERROR_INTERNAL)
           << "Global opcode_and_num_operands Huffman table is missing "
           << "kMarkvNoneOfTheAbove";
  }
  writer_.WriteBits(bits, num_bits);
  return SPV_UNSUPPORTED;
}
// Writes |rank| using the Huffman codec associated with the sequence |mtf|,
// or with |fallback_method| when |mtf| has no codec of its own.
//
// Ranks below MarkvCodec::kMtfSmallestRankEncodedByValue are Huffman coded
// directly. Larger ranks are written as the kMtfRankEncodedByValueSignal
// symbol followed by the variable-width encoded excess over that threshold.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INTERNAL when no codec is available or
// the codec cannot encode the required symbol.
spv_result_t MarkvEncoder::EncodeMtfRankHuffman(uint32_t rank, uint64_t mtf,
                                                uint64_t fallback_method) {
  const auto* codec = GetMtfHuffmanCodec(mtf);
  if (codec == nullptr) {
    assert(fallback_method != kMtfNone);
    codec = GetMtfHuffmanCodec(fallback_method);
    if (codec == nullptr) {
      return Diag(SPV_ERROR_INTERNAL) << "No codec to encode MTF rank";
    }
  }

  uint64_t code = 0;
  size_t code_length = 0;

  if (rank >= MarkvCodec::kMtfSmallestRankEncodedByValue) {
    // Too large for the table: signal "encoded by value", then write the
    // excess over the threshold.
    if (!codec->Encode(MarkvCodec::kMtfRankEncodedByValueSignal, &code,
                       &code_length)) {
      return Diag(SPV_ERROR_INTERNAL)
             << "Failed to encode kMtfRankEncodedByValueSignal";
    }
    writer_.WriteBits(code, code_length);
    writer_.WriteVariableWidthU32(
        rank - MarkvCodec::kMtfSmallestRankEncodedByValue,
        model_->mtf_rank_chunk_length());
    return SPV_SUCCESS;
  }

  // Small rank: the rank itself is the Huffman symbol.
  if (!codec->Encode(rank, &code, &code_length)) {
    return Diag(SPV_ERROR_INTERNAL)
           << "Failed to encode MTF rank with Huffman";
  }
  writer_.WriteBits(code, code_length);
  return SPV_SUCCESS;
}
// Attempts to encode |id| as (descriptor, rank within the move-to-front
// sequence selected by that descriptor).
//
// The long descriptor is used when it is non-zero and present in the Huffman
// table for the current (opcode, operand index). Otherwise the table's
// kMarkvNoneOfTheAbove escape is written (if there is a table) and, provided
// the model's fallback strategy is kShortDescriptor, the short descriptor is
// written as a fixed-width field.
//
// Returns SPV_UNSUPPORTED when the id could not be encoded this way and the
// caller has to use another method, SPV_ERROR_INTERNAL when the table lacks
// the escape symbol, and otherwise the result of EncodeExistingId().
spv_result_t MarkvEncoder::EncodeIdWithDescriptor(uint32_t id) {
  const uint32_t long_descriptor = long_id_descriptors_.GetDescriptor(id);
  auto* codec =
      model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);

  uint64_t code = 0;
  size_t code_length = 0;

  const bool long_descriptor_encoded =
      long_descriptor && codec &&
      codec->Encode(long_descriptor, &code, &code_length);
  if (long_descriptor_encoded) {
    // The descriptor exists and is in the table: write it, then the rank of
    // the id in the sequence keyed by the long descriptor.
    writer_.WriteBits(code, code_length);
    return EncodeExistingId(GetMtfLongIdDescriptor(long_descriptor), id);
  }

  if (codec) {
    // There is a table but no usable entry: emit the escape symbol before
    // switching to the fallback method.
    if (!codec->Encode(MarkvModel::GetMarkvNoneOfTheAbove(), &code,
                       &code_length)) {
      return Diag(SPV_ERROR_INTERNAL)
             << "Descriptor Huffman table for "
             << spvOpcodeString(SpvOp(inst_.opcode)) << " operand index "
             << operand_index_ << " is missing kMarkvNoneOfTheAbove";
    }
    writer_.WriteBits(code, code_length);
  }

  if (model_->id_fallback_strategy() !=
      MarkvModel::IdFallbackStrategy::kShortDescriptor) {
    return SPV_UNSUPPORTED;
  }

  const uint32_t short_descriptor = short_id_descriptors_.GetDescriptor(id);
  writer_.WriteBits(short_descriptor, MarkvCodec::kShortDescriptorNumBits);
  if (short_descriptor == 0) {
    // Forward declared id: nothing to rank against.
    return SPV_UNSUPPORTED;
  }

  // The short descriptor has been written; finish with the rank of the id in
  // the sequence keyed by it.
  return EncodeExistingId(GetMtfShortIdDescriptor(short_descriptor), id);
}
// Writes the rank of |id| within the move-to-front sequence |mtf|. The
// sequence is expected to be non-empty.
//
// Nothing is written when the sequence holds exactly one element. Returns
// SPV_ERROR_INTERNAL if |id| is not found in the sequence, otherwise the
// result of EncodeMtfRankHuffman().
spv_result_t MarkvEncoder::EncodeExistingId(uint64_t mtf, uint32_t id) {
  const auto sequence_size = multi_mtf_.GetSize(mtf);
  assert(sequence_size > 0);

  if (sequence_size != 1) {
    uint32_t rank = 0;
    if (multi_mtf_.RankFromValue(mtf, id, &rank)) {
      return EncodeMtfRankHuffman(rank, mtf, kMtfGenericNonZeroRank);
    }
    return Diag(SPV_ERROR_INTERNAL) << "Id is not in the MTF sequence";
  }

  // A one-element sequence needs no rank: the decoder would make the same
  // decision.
  return SPV_SUCCESS;
}
// Encodes a reference to |id| (an id operand that is neither the result id
// nor the type id of the current instruction).
//
// Descriptor-based encoding is attempted first. If that reports
// SPV_UNSUPPORTED, the id is encoded according to the model's fallback
// strategy: via rule-based move-to-front sequences for kRuleBased, or as a
// rank in kMtfForwardDeclared otherwise.
spv_result_t MarkvEncoder::EncodeRefId(uint32_t id) {
  // Any outcome other than SPV_UNSUPPORTED (success or error) is final.
  const spv_result_t descriptor_result = EncodeIdWithDescriptor(id);
  if (descriptor_result != SPV_UNSUPPORTED) return descriptor_result;

  const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
      SpvOp(inst_.opcode))(operand_index_);
  uint32_t rank = 0;

  if (model_->id_fallback_strategy() !=
      MarkvModel::IdFallbackStrategy::kRuleBased) {
    assert(can_forward_declare);

    if (!multi_mtf_.RankFromValue(kMtfForwardDeclared, id, &rank)) {
      // First occurrence of a forward declared id: register it and write
      // rank zero.
      multi_mtf_.Insert(kMtfForwardDeclared, id);
      rank = 0;
    }

    writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
    return SPV_SUCCESS;
  }

  // Rule-based strategy.
  uint64_t mtf = GetRuleBasedMtf();
  if (mtf == kMtfNone) {
    // No rule applies: rank against the sequence of all ids.
    mtf = kMtfAll;
  } else if (!can_forward_declare) {
    // A rule applies and the operand cannot be a forward reference, so the id
    // must already be known.
    assert(multi_mtf_.HasValue(kMtfAll, id));
    return EncodeExistingId(mtf, id);
  }

  if (!multi_mtf_.RankFromValue(mtf, id, &rank)) {
    // First occurrence of a forward declared id: register it everywhere it
    // belongs and write rank zero.
    multi_mtf_.Insert(kMtfAll, id);
    multi_mtf_.Insert(kMtfForwardDeclared, id);
    if (mtf != kMtfAll) multi_mtf_.Insert(mtf, id);
    rank = 0;
  }

  return EncodeMtfRankHuffman(rank, mtf, kMtfAll);
}
// Encodes the type id of the current instruction.
//
// For OpFunctionParameter nothing is written: the type is taken from the
// front of remaining_function_parameter_types_, which is popped. Otherwise
// descriptor-based encoding is attempted first and, if it reports
// SPV_UNSUPPORTED, the id is ranked in a rule-based move-to-front sequence
// (kMtfTypeNonFunction when no rule applies).
spv_result_t MarkvEncoder::EncodeTypeId() {
  if (inst_.opcode == SpvOpFunctionParameter) {
    assert(!remaining_function_parameter_types_.empty());
    assert(inst_.type_id == remaining_function_parameter_types_.front());
    remaining_function_parameter_types_.pop_front();
    return SPV_SUCCESS;
  }

  // Any outcome other than SPV_UNSUPPORTED (success or error) is final.
  const spv_result_t descriptor_result =
      EncodeIdWithDescriptor(inst_.type_id);
  if (descriptor_result != SPV_UNSUPPORTED) return descriptor_result;

  // Only the rule-based strategy is expected to get this far.
  assert(model_->id_fallback_strategy() ==
         MarkvModel::IdFallbackStrategy::kRuleBased);

  uint64_t mtf = GetRuleBasedMtf();

  // Type ids are never expected to be forward declarable operands.
  assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
      operand_index_));

  if (mtf == kMtfNone) {
    mtf = kMtfTypeNonFunction;
    // Function types should have been handled by GetRuleBasedMtf.
    assert(inst_.opcode != SpvOpFunction);
  }

  return EncodeExistingId(mtf, inst_.type_id);
}
// Encodes the result id of the current instruction.
//
// While kMtfForwardDeclared is non-empty, one flag bit records whether this
// result id defines a previously forward declared id; if so, the id is
// removed from that sequence and its rank follows the flag. Under the
// kRuleBased fallback strategy an id whose rank stayed zero is also added to
// kMtfAll.
//
// Returns SPV_SUCCESS, or SPV_ERROR_INTERNAL if the id could not be removed
// from kMtfForwardDeclared.
spv_result_t MarkvEncoder::EncodeResultId() {
  uint32_t rank = 0;

  // Nothing is written when there are no outstanding forward declarations:
  // the decoder knows there are no forward declared ids to expect.
  if (multi_mtf_.GetSize(kMtfForwardDeclared) != 0) {
    const bool defines_forward_declared =
        multi_mtf_.RankFromValue(kMtfForwardDeclared, inst_.result_id, &rank);

    if (!defines_forward_declared) {
      rank = 0;
      writer_.WriteBits(0, 1);
    } else {
      // This instruction defines the forward declared id, so it no longer
      // belongs in kMtfForwardDeclared.
      if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) {
        return Diag(SPV_ERROR_INTERNAL)
               << "Failed to remove id from kMtfForwardDeclared";
      }
      writer_.WriteBits(1, 1);
      writer_.WriteVariableWidthU32(rank, model_->mtf_rank_chunk_length());
    }
  }

  if (model_->id_fallback_strategy() ==
          MarkvModel::IdFallbackStrategy::kRuleBased &&
      rank == 0) {
    multi_mtf_.Insert(kMtfAll, inst_.result_id);
  }

  return SPV_SUCCESS;
}
// Encodes the typed literal number described by |operand|.
//
// Literals of at most 32 bits are a single word and are handed to
// EncodeNonIdWord(). Wider literals (expected to be at most 64 bits) are
// assembled from two consecutive words, low word first, and written
// according to their number kind: variable-width unsigned, variable-width
// signed, or unencoded for floating point.
//
// Returns SPV_ERROR_INTERNAL for any other number kind.
spv_result_t MarkvEncoder::EncodeLiteralNumber(
    const spv_parsed_operand_t& operand) {
  if (operand.number_bit_width <= 32) {
    const uint32_t single_word = inst_.words[operand.offset];
    return EncodeNonIdWord(single_word);
  }

  assert(operand.number_bit_width <= 64);

  const uint64_t low_word = inst_.words[operand.offset];
  const uint64_t high_word = inst_.words[operand.offset + 1];
  const uint64_t value = low_word | (high_word << 32);

  switch (operand.number_kind) {
    case SPV_NUMBER_UNSIGNED_INT:
      writer_.WriteVariableWidthU64(value, model_->u64_chunk_length());
      break;
    case SPV_NUMBER_SIGNED_INT: {
      // Reinterpret the bit pattern as signed without a value conversion.
      int64_t signed_value = 0;
      std::memcpy(&signed_value, &value, sizeof(signed_value));
      writer_.WriteVariableWidthS64(signed_value, model_->s64_chunk_length(),
                                    model_->s64_block_exponent());
      break;
    }
    case SPV_NUMBER_FLOATING:
      writer_.WriteUnencoded(value);
      break;
    default:
      return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  }

  return SPV_SUCCESS;
}
// Pads the output with zero bits up to the next byte boundary, but only when
// the stream is not already byte aligned and the number of padding bits does
// not exceed |byte_break_if_less_than|. When a logger is attached the padding
// is annotated with "<byte break>".
void MarkvEncoder::AddByteBreak(size_t byte_break_if_less_than) {
  const size_t padding_bits = GetNumBitsToNextByte(writer_.GetNumBits());

  const bool should_pad =
      padding_bits != 0 && padding_bits <= byte_break_if_less_than;
  if (should_pad) {
    if (logger_) {
      logger_->AppendWhitespaces(kCommentNumWhitespaces);
      logger_->AppendText("<byte break>");
    }
    writer_.WriteBits(0, padding_bits);
  }
}
// Encodes one parsed instruction: first the opcode together with the operand
// count, then every operand in order, then an optional byte break.
//
// |inst| is copied into inst_ and becomes the "current instruction" that the
// per-operand encoders read. operand_index_ and operand_ are updated as the
// operands are visited.
//
// Returns SPV_SUCCESS on success, SPV_REQUESTED_TERMINATION if the logger's
// DebugInstruction() returns false, SPV_ERROR_INVALID_BINARY for a literal
// string without a terminator inside its operand, or the first non-success
// result reported by one of the operand encoders.
spv_result_t MarkvEncoder::EncodeInstruction(
    const spv_parsed_instruction_t& inst) {
  SpvOp opcode = SpvOp(inst.opcode);
  inst_ = inst;

  LogDisassemblyInstruction();

  const spv_result_t opcode_encoding_result =
      EncodeOpcodeAndNumOperands(opcode, inst.num_operands);
  // Negative results are errors and are propagated as they are.
  if (opcode_encoding_result < 0) return opcode_encoding_result;

  if (opcode_encoding_result != SPV_SUCCESS) {
    // Fallback encoding for opcode and num_operands.
    writer_.WriteVariableWidthU32(opcode, model_->opcode_chunk_length());

    if (!OpcodeHasFixedNumberOfOperands(opcode)) {
      // If the opcode has a variable number of operands, encode the number of
      // operands with the instruction.
      if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);

      writer_.WriteVariableWidthU16(inst.num_operands,
                                    model_->num_operands_chunk_length());
    }
  }

  // Write operands.
  const uint32_t num_operands = inst_.num_operands;
  for (operand_index_ = 0; operand_index_ < num_operands; ++operand_index_) {
    operand_ = inst_.operands[operand_index_];

    if (logger_) {
      logger_->AppendWhitespaces(kCommentNumWhitespaces);
      logger_->AppendText("<");
      logger_->AppendText(spvOperandTypeStr(operand_.type));
      logger_->AppendText(">");
    }

    switch (operand_.type) {
      case SPV_OPERAND_TYPE_RESULT_ID:
      case SPV_OPERAND_TYPE_TYPE_ID:
      case SPV_OPERAND_TYPE_ID:
      case SPV_OPERAND_TYPE_OPTIONAL_ID:
      case SPV_OPERAND_TYPE_SCOPE_ID:
      case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
        const uint32_t id = inst_.words[operand_.offset];

        // Type ids and result ids have dedicated encoders; every other id
        // operand is a reference.
        spv_result_t result = SPV_SUCCESS;
        if (operand_.type == SPV_OPERAND_TYPE_TYPE_ID) {
          result = EncodeTypeId();
        } else if (operand_.type == SPV_OPERAND_TYPE_RESULT_ID) {
          result = EncodeResultId();
        } else {
          result = EncodeRefId(id);
        }
        if (result != SPV_SUCCESS) return result;

        PromoteIfNeeded(id);
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_INTEGER: {
        const spv_result_t result =
            EncodeNonIdWord(inst_.words[operand_.offset]);
        if (result != SPV_SUCCESS) return result;
        break;
      }

      case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: {
        const spv_result_t result = EncodeLiteralNumber(operand_);
        if (result != SPV_SUCCESS) return result;
        break;
      }

      case SPV_OPERAND_TYPE_LITERAL_STRING: {
        const char* src =
            reinterpret_cast<const char*>(&inst_.words[operand_.offset]);

        // Locate the terminator before anything else touches the string, so
        // that an unterminated literal is rejected without reading past the
        // words that belong to this operand.
        const size_t max_length = operand_.num_words * 4;
        const size_t length = spv_strnlen_s(src, max_length);
        if (length == max_length)
          return Diag(SPV_ERROR_INVALID_BINARY)
                 << "Failed to find terminal character of literal string";

        auto* codec = model_->GetLiteralStringHuffmanCodec(opcode);
        if (codec) {
          uint64_t bits = 0;
          size_t num_bits = 0;
          // |length| has been validated above, so this copies exactly the
          // characters in front of the terminator.
          const std::string str(src, length);
          if (codec->Encode(str, &bits, &num_bits)) {
            writer_.WriteBits(bits, num_bits);
            break;
          } else {
            // The string is not in the table: write the escape symbol and
            // continue with the character-by-character encoding below.
            bool result =
                codec->Encode("kMarkvNoneOfTheAbove", &bits, &num_bits);
            (void)result;
            assert(result);
            writer_.WriteBits(bits, num_bits);
          }
        }

        // Write the characters including the terminator.
        for (size_t i = 0; i < length + 1; ++i) writer_.WriteUnencoded(src[i]);
        break;
      }

      default: {
        // Any other operand is encoded word by word.
        for (uint32_t i = 0; i < operand_.num_words; ++i) {
          const uint32_t word = inst_.words[operand_.offset + i];
          const spv_result_t result = EncodeNonIdWord(word);
          if (result != SPV_SUCCESS) return result;
        }
        break;
      }
    }
  }

  AddByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte);

  if (logger_) {
    logger_->NewLine();
    logger_->NewLine();
    if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
  }

  ProcessCurInstruction();

  return SPV_SUCCESS;
}
} // namespace comp
} // namespace spvtools