// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/comp/markv.h"

#include <cassert>
#include <memory>
#include <string>
#include <vector>

#include "source/comp/markv_decoder.h"
#include "source/comp/markv_encoder.h"
namespace spvtools {
namespace comp {
namespace {

// Callbacks handed to spvBinaryParse() by SpirvToMarkv(). In both, |user_data|
// is the MarkvEncoder passed to spvBinaryParse(); each callback forwards to
// the matching encoder method and returns that method's result unchanged.

// Header callback: forwards the parsed module header fields to the encoder.
spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
                          uint32_t magic, uint32_t version, uint32_t generator,
                          uint32_t id_bound, uint32_t schema) {
  // static_cast is the named cast intended for void* -> T* round trips.
  MarkvEncoder* encoder = static_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
                               schema);
}

// Instruction callback: forwards each parsed instruction to the encoder.
spv_result_t EncodeInstruction(void* user_data,
                               const spv_parsed_instruction_t* inst) {
  MarkvEncoder* encoder = static_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeInstruction(*inst);
}

}  // namespace
// Encodes the SPIR-V module |spirv| into MARK-V and stores the result in
// |*markv|.
//
// A copy of |context| is used so that |message_consumer| can be installed
// without modifying the caller's context. If the validator options derived
// from |options| are non-null, |spirv| is validated first and any validation
// failure code is returned as-is. If |log_consumer| or |debug_consumer| is
// set, the encoder's logger is created and given a header-less disassembly
// of |spirv|.
//
// Returns:
//   SPV_SUCCESS on success (|*markv| is overwritten);
//   SPV_ERROR_INVALID_POINTER if |context| or |markv| is null;
//   the validator's result if validation is enabled and fails;
//   SPV_ERROR_INVALID_BINARY (after emitting a diagnostic) if disassembly or
//   encoding fails.
spv_result_t SpirvToMarkv(
    spv_const_context context, const std::vector<uint32_t>& spirv,
    const MarkvCodecOptions& options, const MarkvModel& markv_model,
    MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
    MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
  // Both pointers are dereferenced unconditionally below.
  if (!context || !markv) return SPV_ERROR_INVALID_POINTER;

  // Local copy so the caller's message consumer is left untouched.
  spv_context_t hijack_context = *context;
  SetContextMessageConsumer(&hijack_context, message_consumer);

  // NOTE(review): if GetValidatorOptions() returns a freshly allocated
  // spv_validator_options (e.g. via spvValidatorOptionsCreate), nothing here
  // destroys it -- confirm ownership and release it if so.
  spv_validator_options validator_options =
      MarkvDecoder::GetValidatorOptions(options);
  if (validator_options) {
    spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
    const spv_result_t result = spvValidateWithOptions(
        &hijack_context, validator_options, &spirv_binary, nullptr);
    if (result != SPV_SUCCESS) return result;
  }

  MarkvEncoder encoder(&hijack_context, options, &markv_model);

  spv_position_t position = {};
  if (log_consumer || debug_consumer) {
    encoder.CreateLogger(log_consumer, debug_consumer);

    spv_text raw_text = nullptr;
    const spv_result_t disassembly_result = spvBinaryToText(
        &hijack_context, spirv.data(), spirv.size(),
        SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &raw_text, nullptr);
    // Take ownership immediately so the disassembly is released on every
    // path, including if copying it or SetDisassembly() throws.
    std::unique_ptr<spv_text_t, decltype(&spvTextDestroy)> text(
        raw_text, &spvTextDestroy);
    if (disassembly_result != SPV_SUCCESS) {
      return DiagnosticStream(position, hijack_context.consumer, "",
                              SPV_ERROR_INVALID_BINARY)
             << "Failed to disassemble SPIR-V binary.";
    }
    assert(text != nullptr);
    encoder.SetDisassembly(std::string(text->str, text->length));
  }

  // spvBinaryParse drives the encoder through the EncodeHeader /
  // EncodeInstruction callbacks; |encoder| is passed as their user_data.
  if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
                     EncodeHeader, EncodeInstruction,
                     nullptr) != SPV_SUCCESS) {
    return DiagnosticStream(position, hijack_context.consumer, "",
                            SPV_ERROR_INVALID_BINARY)
           << "Unable to encode to MARK-V.";
  }

  *markv = encoder.GetMarkvBinary();
  return SPV_SUCCESS;
}
// Decodes the MARK-V bytes |markv| back into SPIR-V words stored in |*spirv|.
//
// A copy of |context| is used so that |message_consumer| can be installed
// without modifying the caller's context. If |log_consumer| or
// |debug_consumer| is set, the decoder's logger is created before decoding.
//
// Returns:
//   SPV_SUCCESS on success;
//   SPV_ERROR_INVALID_POINTER if |context| or |spirv| is null;
//   SPV_ERROR_INVALID_BINARY (after emitting a diagnostic) if decoding fails.
spv_result_t MarkvToSpirv(
    spv_const_context context, const std::vector<uint8_t>& markv,
    const MarkvCodecOptions& options, const MarkvModel& markv_model,
    MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
    MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
  // Both pointers are dereferenced unconditionally below.
  if (!context || !spirv) return SPV_ERROR_INVALID_POINTER;

  // Local copy so the caller's message consumer is left untouched.
  spv_context_t hijack_context = *context;
  SetContextMessageConsumer(&hijack_context, message_consumer);

  MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
  if (log_consumer || debug_consumer) {
    decoder.CreateLogger(log_consumer, debug_consumer);
  }

  if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
    spv_position_t position = {};
    return DiagnosticStream(position, hijack_context.consumer, "",
                            SPV_ERROR_INVALID_BINARY)
           << "Unable to decode MARK-V.";
  }

  // A successful decode is expected to produce a non-empty module.
  assert(!spirv->empty());
  return SPV_SUCCESS;
}
} // namespace comp
} // namespace spvtools