// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SOURCE_COMP_MARKV_CODEC_H_
#define SOURCE_COMP_MARKV_CODEC_H_

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <list>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "source/assembly_grammar.h"
#include "source/comp/huffman_codec.h"
#include "source/comp/markv_model.h"
#include "source/comp/move_to_front.h"
#include "source/diagnostic.h"
#include "source/id_descriptor.h"

#include "source/val/instruction.h"

// Base class for MARK-V encoder and decoder. Contains common functionality
// such as:
// - Validator connection and validation state.
// - SPIR-V grammar and helper functions.

namespace spvtools {
namespace comp {

class MarkvLogger;

// Handles for move-to-front (mtf) sequences. Each handle identifies one
// sequence. Enumerators whose name ends with "Begin" are not single handles:
// they are the base of a handle space, and the actual handle is obtained by
// adding a 16-bit or 32-bit value (type id, opcode, descriptor, ...) to the
// base.
enum : uint64_t {
  kMtfNone = 0,
  // All ids.
  kMtfAll,
  // All forward declared ids.
  kMtfForwardDeclared,
  // All type ids except for those generated by OpTypeFunction.
  kMtfTypeNonFunction,
  // All labels.
  kMtfLabel,
  // All ids created by instructions which had type_id.
  kMtfObject,
  // All types generated by OpTypeFloat, OpTypeInt, OpTypeBool.
  kMtfTypeScalar,
  // All composite types.
  kMtfTypeComposite,
  // Boolean type or any vector type of it.
  kMtfTypeBoolScalarOrVector,
  // All float types or any vector of floats type.
  kMtfTypeFloatScalarOrVector,
  // All int types or any vector of ints type.
  kMtfTypeIntScalarOrVector,
  // All types declared as return types in OpTypeFunction.
  kMtfTypeReturnedByFunction,
  // All composite objects.
  kMtfComposite,
  // All bool objects or vectors of bools.
  kMtfBoolScalarOrVector,
  // All float objects or vectors of floats.
  kMtfFloatScalarOrVector,
  // All int objects or vectors of ints.
  kMtfIntScalarOrVector,
  // All pointer types which point to composite types.
  kMtfTypePointerToComposite,
  // Used by EncodeMtfRankHuffman.
  kMtfGenericNonZeroRank,
  // Handle space for ids of specific type.
  kMtfIdOfTypeBegin = 0x10000,
  // Handle space for ids generated by specific opcode. Note: despite the
  // missing "Begin" suffix this is also the base of a handle space (the
  // opcode is added to it).
  kMtfIdGeneratedByOpcode = 0x20000,
  // Handle space for ids of objects with type generated by specific opcode.
  kMtfIdWithTypeGeneratedByOpcodeBegin = 0x30000,
  // All vectors of specific component type.
  kMtfVectorOfComponentTypeBegin = 0x40000,
  // All vector types of specific size.
  kMtfTypeVectorOfSizeBegin = 0x50000,
  // All pointer types to specific type.
  kMtfPointerToTypeBegin = 0x60000,
  // All function types which return specific type.
  kMtfFunctionTypeWithReturnTypeBegin = 0x70000,
  // All function objects which return specific type.
  kMtfFunctionWithReturnTypeBegin = 0x80000,
  // Short id descriptor space (max 16-bit).
  kMtfShortIdDescriptorSpaceBegin = 0x90000,
  // Long id descriptor space (32-bit). Starts above the 32-bit range, which is
  // why the underlying type of this enum is uint64_t.
  kMtfLongIdDescriptorSpaceBegin = 0x100000000,
};

// Common base of the MARK-V encoder and decoder. Holds the state both sides
// must keep in sync while walking a module instruction by instruction:
// id definitions, id -> type mapping, move-to-front sequences, id descriptors
// and the Huffman codecs used for mtf ranks.
class MarkvCodec {
 public:
  static const uint32_t kMarkvMagicNumber;

  // Mtf ranks smaller than this are encoded with Huffman coding.
  static const uint32_t kMtfSmallestRankEncodedByValue;

  // Signals that the mtf rank is too large to be encoded with Huffman.
  static const uint32_t kMtfRankEncodedByValueSignal;

  static const uint32_t kShortDescriptorNumBits;

  static const size_t kByteBreakAfterInstIfLessThanUntilNextByte;

  static uint32_t GetMarkvVersion();

  virtual ~MarkvCodec();

 protected:
  // Header stored at the beginning of a MARK-V binary.
  struct MarkvHeader {
    // Defined out of line; expected to set |magic_number| and |markv_version|,
    // which have no in-class initializer.
    MarkvHeader();

    uint32_t magic_number;
    uint32_t markv_version;
    // Magic number to identify or verify MarkvModel used for encoding.
    uint32_t markv_model = 0;
    uint32_t markv_length_in_bits = 0;
    uint32_t spirv_version = 0;
    uint32_t spirv_generator = 0;
  };

  // |model| is owned by the caller, must be not null and valid during the
  // lifetime of the codec.
  MarkvCodec(spv_const_context context, spv_validator_options validator_options,
             const MarkvModel* model);

  // Returns instruction which created |id| or nullptr if such instruction was
  // not registered.
  const val::Instruction* FindDef(uint32_t id) const {
    const auto it = id_to_def_instruction_.find(id);
    if (it == id_to_def_instruction_.end()) return nullptr;
    return it->second;
  }

  size_t GetNumBitsToNextByte(size_t bit_pos) const;
  bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) const;

  // Returns type id of vector type component.
  // |vector_type_id| must be the result id of a registered OpTypeVector
  // instruction; this is only checked by assertions.
  uint32_t GetVectorComponentType(uint32_t vector_type_id) const {
    const val::Instruction* type_inst = FindDef(vector_type_id);
    assert(type_inst);
    assert(type_inst->opcode() == SpvOpTypeVector);

    // The component type is read from the word of the second parsed operand
    // (index 1) of the OpTypeVector instruction.
    const uint32_t component_type =
        type_inst->word(type_inst->operands()[1].offset);
    return component_type;
  }

  // Returns mtf handle for ids of given type.
  uint64_t GetMtfIdOfType(uint32_t type_id) const {
    return kMtfIdOfTypeBegin + type_id;
  }

  // Returns mtf handle for ids generated by given opcode.
  uint64_t GetMtfIdGeneratedByOpcode(SpvOp opcode) const {
    return kMtfIdGeneratedByOpcode + opcode;
  }

  // Returns mtf handle for ids of type generated by given opcode.
  uint64_t GetMtfIdWithTypeGeneratedByOpcode(SpvOp opcode) const {
    return kMtfIdWithTypeGeneratedByOpcodeBegin + opcode;
  }

  // Returns mtf handle for vectors of specific component type.
  uint64_t GetMtfVectorOfComponentType(uint32_t type_id) const {
    return kMtfVectorOfComponentTypeBegin + type_id;
  }

  // Returns mtf handle for vector type of specific size.
  uint64_t GetMtfTypeVectorOfSize(uint32_t size) const {
    return kMtfTypeVectorOfSizeBegin + size;
  }

  // Returns mtf handle for pointers to specific type.
  uint64_t GetMtfPointerToType(uint32_t type_id) const {
    return kMtfPointerToTypeBegin + type_id;
  }

  // Returns mtf handle for function types with given return type.
  uint64_t GetMtfFunctionTypeWithReturnType(uint32_t type_id) const {
    return kMtfFunctionTypeWithReturnTypeBegin + type_id;
  }

  // Returns mtf handle for functions with given return type.
  uint64_t GetMtfFunctionWithReturnType(uint32_t type_id) const {
    return kMtfFunctionWithReturnTypeBegin + type_id;
  }

  // Returns mtf handle for the given long id descriptor.
  uint64_t GetMtfLongIdDescriptor(uint32_t descriptor) const {
    return kMtfLongIdDescriptorSpaceBegin + descriptor;
  }

  // Returns mtf handle for the given short id descriptor.
  uint64_t GetMtfShortIdDescriptor(uint32_t descriptor) const {
    return kMtfShortIdDescriptorSpaceBegin + descriptor;
  }

  // Process data from the current instruction. This would update MTFs and
  // other data containers.
  void ProcessCurInstruction();

  // Returns move-to-front handle to be used for the current operand slot.
  // Mtf handle is chosen based on a set of rules defined by SPIR-V grammar.
  uint64_t GetRuleBasedMtf();

  // Returns words of the current instruction. Decoder has a different
  // implementation and the array is valid only until the previously decoded
  // word.
  virtual const uint32_t* GetInstWords() const { return inst_.words; }

  // Returns the opcode of the previous instruction, or SpvOpNop if no
  // instruction has been registered yet.
  SpvOp GetPrevOpcode() const {
    if (instructions_.empty()) return SpvOpNop;

    return instructions_.back()->opcode();
  }

  // Returns diagnostic stream, position index is set to instruction number.
  DiagnosticStream Diag(spv_result_t error_code) const {
    return DiagnosticStream({0, 0, instructions_.size()}, context_->consumer,
                            "", error_code);
  }

  // Returns current id bound.
  uint32_t GetIdBound() const { return id_bound_; }

  // Sets current id bound, expected to be no lower than the previous one.
  void SetIdBound(uint32_t id_bound) {
    assert(id_bound >= id_bound_);
    id_bound_ = id_bound;
  }

  // Returns Huffman codec for ranks of the mtf with given |handle|.
  // Different mtfs can use different rank distributions.
  // May return nullptr if the codec doesn't exist.
  const HuffmanCodec<uint32_t>* GetMtfHuffmanCodec(uint64_t handle) const {
    const auto it = mtf_huffman_codecs_.find(handle);
    if (it == mtf_huffman_codecs_.end()) return nullptr;
    return it->second.get();
  }

  // Promotes id in all move-to-front sequences if ids can be shared by multiple
  // sequences.
  void PromoteIfNeeded(uint32_t id) {
    if (!model_->AnyDescriptorHasCodingScheme() &&
        model_->id_fallback_strategy() ==
            MarkvModel::IdFallbackStrategy::kShortDescriptor) {
      // Move-to-front sequences do not share ids. Nothing to do.
      return;
    }
    multi_mtf_.Promote(id);
  }

  spv_validator_options validator_options_ = nullptr;
  const AssemblyGrammar grammar_;
  MarkvHeader header_;

  // MARK-V model, not owned.
  const MarkvModel* model_ = nullptr;

  // Current instruction, current operand and current operand index.
  // Value-initialized so that reading them (e.g. via GetInstWords()) before
  // the first instruction is processed yields null/zero instead of
  // indeterminate values.
  spv_parsed_instruction_t inst_ = {};
  spv_parsed_operand_t operand_ = {};
  uint32_t operand_index_ = 0;

  // Maps a result ID to its type ID.  By convention:
  //  - a result ID that is a type definition maps to itself.
  //  - a result ID without a type maps to 0.  (E.g. for OpLabel)
  std::unordered_map<uint32_t, uint32_t> id_to_type_id_;

  // Container for all move-to-front sequences.
  MultiMoveToFront multi_mtf_;

  // Id of the current function or zero if outside of function.
  uint32_t cur_function_id_ = 0;

  // Return type of the current function.
  uint32_t cur_function_return_type_ = 0;

  // Remaining function parameter types. This container is filled on OpFunction,
  // and drained on OpFunctionParameter.
  std::list<uint32_t> remaining_function_parameter_types_;

  // List of ids local to the current function.
  std::vector<uint32_t> ids_local_to_cur_function_;

  // List of instructions in the order they are given in the module.
  std::vector<std::unique_ptr<const val::Instruction>> instructions_;

  // Container/computer for long (32-bit) id descriptors.
  IdDescriptorCollection long_id_descriptors_;

  // Container/computer for short id descriptors.
  // Short descriptors are stored in uint32_t, but their actual bit width is
  // defined with kShortDescriptorNumBits.
  // It doesn't seem logical to have a different computer for short id
  // descriptors, since one could actually map/truncate long descriptors.
  // But as short descriptors have collisions, the efficiency of
  // compression depends on the collision pattern, and short descriptors
  // produced by function ShortHashU32Array have been empirically proven to
  // produce better results.
  IdDescriptorCollection short_id_descriptors_;

  // Huffman codecs for move-to-front ranks. The map key is mtf handle. Doesn't
  // need to contain a different codec for every handle as most use one and the
  // same.
  std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
      mtf_huffman_codecs_;

  // If not nullptr, codec will log comments on the compression process.
  std::unique_ptr<MarkvLogger> logger_;

  spv_const_context context_ = nullptr;

 private:
  // Maps result id to the instruction which defined it. The pointers are
  // non-owning.
  std::unordered_map<uint32_t, const val::Instruction*> id_to_def_instruction_;

  uint32_t id_bound_ = 1;
};

}  // namespace comp
}  // namespace spvtools

#endif  // SOURCE_COMP_MARKV_CODEC_H_