// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/comp/markv.h"

#include "source/comp/markv_decoder.h"
#include "source/comp/markv_encoder.h"

namespace spvtools {
namespace comp {
namespace {

spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
                          uint32_t magic, uint32_t version, uint32_t generator,
                          uint32_t id_bound, uint32_t schema) {
  MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
                               schema);
}

spv_result_t EncodeInstruction(void* user_data,
                               const spv_parsed_instruction_t* inst) {
  MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
  return encoder->EncodeInstruction(*inst);
}

}  // namespace

spv_result_t SpirvToMarkv(
    spv_const_context context, const std::vector<uint32_t>& spirv,
    const MarkvCodecOptions& options, const MarkvModel& markv_model,
    MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
    MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
  spv_context_t hijack_context = *context;
  SetContextMessageConsumer(&hijack_context, message_consumer);

  spv_validator_options validator_options =
      MarkvDecoder::GetValidatorOptions(options);
  if (validator_options) {
    spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
    const spv_result_t result = spvValidateWithOptions(
        &hijack_context, validator_options, &spirv_binary, nullptr);
    if (result != SPV_SUCCESS) return result;
  }

  MarkvEncoder encoder(&hijack_context, options, &markv_model);

  spv_position_t position = {};
  if (log_consumer || debug_consumer) {
    encoder.CreateLogger(log_consumer, debug_consumer);

    spv_text text = nullptr;
    if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
                        SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
                        nullptr) != SPV_SUCCESS) {
      return DiagnosticStream(position, hijack_context.consumer, "",
                              SPV_ERROR_INVALID_BINARY)
             << "Failed to disassemble SPIR-V binary.";
    }
    assert(text);
    encoder.SetDisassembly(std::string(text->str, text->length));
    spvTextDestroy(text);
  }

  if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
                     EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
    return DiagnosticStream(position, hijack_context.consumer, "",
                            SPV_ERROR_INVALID_BINARY)
           << "Unable to encode to MARK-V.";
  }

  *markv = encoder.GetMarkvBinary();
  return SPV_SUCCESS;
}

spv_result_t MarkvToSpirv(
    spv_const_context context, const std::vector<uint8_t>& markv,
    const MarkvCodecOptions& options, const MarkvModel& markv_model,
    MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
    MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
  spv_position_t position = {};
  spv_context_t hijack_context = *context;
  SetContextMessageConsumer(&hijack_context, message_consumer);

  MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);

  if (log_consumer || debug_consumer)
    decoder.CreateLogger(log_consumer, debug_consumer);

  if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
    return DiagnosticStream(position, hijack_context.consumer, "",
                            SPV_ERROR_INVALID_BINARY)
           << "Unable to decode MARK-V.";
  }

  assert(!spirv->empty());
  return SPV_SUCCESS;
}

}  // namespace comp
}  // namespace spvtools