"""Qnt primitive used by the GEMM function.

"""

import neon_emitter


class Error(Exception):
  """Base class for every error raised by this module."""


class ConfigurationError(Error):
  """Raised when a generator is asked for an unsupported configuration.

  In this module that means an unsupported number of lanes or leftovers.
  """


class QntLane(object):
  """Bundle of registers describing one lane handled by a qnt function.

  Attributes:
    source: register holding the address the lane reads its input from.
    output: register holding the address the lane writes its result to.
    offset: quad register holding the offset added to the lane's values.
    load_1: first quad register the lane's input is loaded into.
    load_2: second quad register the lane's input is loaded into.
  """

  def __init__(self, source, output, offset, load_1, load_2):
    (self.source, self.output) = (source, output)
    self.offset = offset
    (self.load_1, self.load_2) = (load_1, load_2)


def BuildName(lanes, leftovers, aligned):
  """Returns the name of the qnt function for the given configuration.

  The name is 'qnt_<lanes>x8', followed by '_<leftovers>' when leftovers is
  non-zero and by '_aligned' when aligned is true.
  """
  parts = ['qnt_%dx8' % lanes]
  if leftovers:
    parts.append('%d' % leftovers)
  if aligned:
    parts.append('aligned')
  return '_'.join(parts)


def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
  """Emits one all-lanes offset load per lane, each into a new quad register.

  Args:
    emitter: code emitter that receives the load instructions.
    registers: register allocator providing the quad registers.
    lanes: number of lanes; must be 1, 2 or 3.
    offsets: register dereferenced (with increment) to read each offset.

  Returns:
    List with one quad register per lane, in lane order.

  Raises:
    ConfigurationError: if lanes is not 1, 2 or 3.
  """
  if lanes not in (1, 2, 3):
    raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

  duplicated = []
  for unused_lane in range(lanes):
    quad = registers.QuadRegister()
    # Both halves of the quad register are targeted with the all-lanes form.
    low_half = emitter.AllLanes(registers.Low(quad))
    high_half = emitter.AllLanes(registers.High(quad))
    address = emitter.DereferenceIncrement(offsets, 32)
    emitter.EmitVLoadA('1.32', [low_half, high_half], address)
    duplicated.append(quad)
  return duplicated


def GenerateQntLanes(emitter,
                     registers,
                     qnt_lanes,
                     source,
                     stride,
                     destination,
                     destination_stride,
                     offsets):
  """Builds the QntLane descriptors and emits the per-lane pointer setup.

  Args:
    emitter: code emitter that receives the setup instructions.
    registers: register allocator.
    qnt_lanes: number of lanes to prepare.
    source: register with the input address of the first lane.
    stride: register added to one lane's input address to get the next one.
    destination: register with the output address of the first lane.
    destination_stride: register added to one lane's output address to get
      the next one.
    offsets: register with the address of the per-lane offsets.

  Returns:
    List of QntLane objects, one per lane, in lane order.
  """
  offset_registers = LoadAndDuplicateOffsets(
      emitter, registers, qnt_lanes, offsets)

  # The first lane works directly on the source and destination registers.
  lanes = [QntLane(source,
                   destination,
                   offset_registers[0],
                   registers.QuadRegister(),   # load 1
                   registers.QuadRegister())]  # load 2

  # Every further lane gets its own pointers, derived from the previous lane.
  for index in range(1, qnt_lanes):
    previous = lanes[-1]
    lane_source = registers.GeneralRegister()
    lane_output = registers.GeneralRegister()
    lanes.append(QntLane(lane_source,
                         lane_output,
                         offset_registers[index],
                         registers.QuadRegister(),   # load 1
                         registers.QuadRegister()))  # load 2
    emitter.EmitAdd(lane_source, previous.source, stride)
    emitter.EmitAdd(lane_output, previous.output, destination_stride)
  return lanes


def DuplicateRegister(emitter, registers, value):
  """Allocates a quad register and emits a '32' VDup of value into it.

  Returns:
    The newly allocated quad register.
  """
  duplicated = registers.QuadRegister()
  emitter.EmitVDup('32', duplicated, value)
  return duplicated


def GenerateQuantize(emitter,
                     registers,
                     lanes,
                     lane_temps,
                     multiplicative_offset,
                     rounding_offset,
                     shift):
  """Emits the quantization steps: add offsets, multiply, round, shift, narrow.

  Each step is emitted for all entries before the next step starts.

  Args:
    emitter: code emitter that receives the instructions.
    registers: register allocator; used to get the low half of each temp.
    lanes: list of entries indexed as [values, offset, narrowed], where values
      is updated in place and narrowed receives the VQmovn result.
    lane_temps: registers narrowed once more (VQmovun) into their low half.
    multiplicative_offset: register the values are multiplied by.
    rounding_offset: register added to the values after the multiplication.
    shift: register the values are shifted by.
  """
  # Per-entry offset first; the remaining in-place steps share one operand.
  for entry in lanes:
    emitter.EmitVAdd('i32', entry[0], entry[0], entry[1])

  shared_operand_steps = (
      (emitter.EmitVMul, 'i32', multiplicative_offset),
      (emitter.EmitVAdd, 'i32', rounding_offset),
      (emitter.EmitVShl, 's32', shift),
  )
  for (emit, element_type, operand) in shared_operand_steps:
    for entry in lanes:
      emit(element_type, entry[0], entry[0], operand)

  for entry in lanes:
    emitter.EmitVQmovn('s32', entry[2], entry[0])

  for temp in lane_temps:
    emitter.EmitVQmovun('s16', registers.Low(temp), temp)


def GenerateLoadQuantizeStore(emitter,
                              registers,
                              lanes,
                              multiplicative_offset,
                              rounding_offset,
                              shift,
                              alignment):
  """Emits one full load / quantize / store step for every lane.

  Args:
    emitter: code emitter that receives the instructions.
    registers: register allocator; temporaries taken here are freed again.
    lanes: list of QntLane objects to process.
    multiplicative_offset: register forwarded to GenerateQuantize.
    rounding_offset: register forwarded to GenerateQuantize.
    shift: register forwarded to GenerateQuantize.
    alignment: alignment passed to the dereference of the output address.
  """
  lane_temps = [registers.QuadRegister() for unused_lane in lanes]

  # Fill both load registers of each lane and advance the source address.
  for lane in lanes:
    load_targets = [registers.Low(lane.load_1),
                    registers.High(lane.load_1),
                    registers.Low(lane.load_2),
                    registers.High(lane.load_2)]
    emitter.EmitVLoadA('1.32',
                       load_targets,
                       emitter.DereferenceIncrement(lane.source, 64))

  for lane in lanes:
    emitter.EmitPld(lane.source)

  # Each load register is narrowed into one half of the lane's temporary.
  quantize_setup = []
  for (lane_temp, lane) in zip(lane_temps, lanes):
    quantize_setup.extend([
        [lane.load_1, lane.offset, registers.Low(lane_temp)],
        [lane.load_2, lane.offset, registers.High(lane_temp)],
    ])

  GenerateQuantize(emitter,
                   registers,
                   quantize_setup,
                   lane_temps,
                   multiplicative_offset,
                   rounding_offset,
                   shift)

  for (index, lane) in enumerate(lanes):
    emitter.EmitVStore('1.8',
                       registers.Low(lane_temps[index]),
                       emitter.DereferenceIncrement(lane.output, alignment))

  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)


def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
  """Emits the loads for a leftover count that is not a multiple of 8.

  The leftover values of a lane fill its double registers in the order
  low(load_1), high(load_1), low(load_2), high(load_2). Every pair of values
  takes one whole double register; an odd last value is loaded into lane 0 of
  the next double register.

  Args:
    emitter: code emitter that receives the load instructions.
    registers: register allocator; used to get the halves of quad registers.
    leftovers: number of values left to load, 1 to 7 inclusive.
    lanes: list of QntLane objects to load into.

  Raises:
    ConfigurationError: if leftovers is not between 1 and 7 inclusive.
  """
  if leftovers not in range(1, 8):
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)

  (full_doubles, has_single_value) = divmod(leftovers, 2)

  # Accessors for the double registers of a lane, in fill order.
  double_registers = (
      lambda lane: registers.Low(lane.load_1),
      lambda lane: registers.High(lane.load_1),
      lambda lane: registers.Low(lane.load_2),
      lambda lane: registers.High(lane.load_2),
  )

  if full_doubles:
    for lane in lanes:
      targets = [double_registers[i](lane) for i in range(full_doubles)]
      if has_single_value:
        # The single-value load that follows needs the advanced address.
        address = emitter.DereferenceIncrement(lane.source, 64)
      else:
        address = emitter.Dereference(lane.source, 64)
      if full_doubles == 1:
        emitter.EmitVLoad('1.32', targets[0], address)
      else:
        emitter.EmitVLoadA('1.32', targets, address)

  if has_single_value:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(double_registers[full_doubles](lane), 0),
                        emitter.Dereference(lane.source, None))


def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
  """Emits the stores for a leftover count that is not a multiple of 8.

  The results are read from the low half of each lane's temporary register.
  Each leftover count is covered by up to three stores of decreasing width,
  and every store is emitted for all lanes before the next one starts.

  Args:
    emitter: code emitter that receives the store instructions.
    registers: register allocator; used to get the low half of each temp.
    leftovers: number of values left to store, 1 to 7 inclusive.
    lane_temps: quad registers holding the quantized results, one per lane.
    lanes: list of QntLane objects providing the output addresses.

  Raises:
    ConfigurationError: if leftovers is not between 1 and 7 inclusive.
  """
  # Per leftover count: (store type, register lane index, post-increment).
  # NOTE(review): for 7 leftovers the last store also post-increments, unlike
  # the last store of every other count; looks harmless as nothing in this
  # function uses the address afterwards -- confirm before changing.
  store_plans = {
      1: [('1.8', 0, False)],
      2: [('1.16', 0, False)],
      3: [('1.16', 0, True), ('1.8', 2, False)],
      4: [('1.32', 0, False)],
      5: [('1.32', 0, True), ('1.8', 4, False)],
      6: [('1.32', 0, True), ('1.16', 2, False)],
      7: [('1.32', 0, True), ('1.16', 2, True), ('1.8', 6, True)],
  }

  targets = [(registers.Low(temp), lane.output)
             for (temp, lane) in zip(lane_temps, lanes)]

  plan = store_plans.get(leftovers)
  if plan is None:
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)

  for (store_type, lane_index, post_increment) in plan:
    for (low_half, pointer) in targets:
      stored = emitter.Lane(low_half, lane_index)
      if post_increment:
        address = emitter.DereferenceIncrement(pointer, None)
      else:
        address = emitter.Dereference(pointer, None)
      emitter.EmitVStore(store_type, stored, address)


def GenerateLeftoverLoadQuantizeStore(emitter,
                                      registers,
                                      leftovers,
                                      lanes,
                                      multiplicative_offset,
                                      rounding_offset,
                                      shift):
  """Handles the leftovers when the row size is not a multiple of 8.

  Args:
    emitter: code emitter that receives the instructions.
    registers: register allocator; temporaries taken here are freed again.
    leftovers: number of leftover values per lane, 1 to 7 inclusive.
    lanes: list of QntLane objects to process.
    multiplicative_offset: register forwarded to GenerateQuantize.
    rounding_offset: register forwarded to GenerateQuantize.
    shift: register forwarded to GenerateQuantize.

  Raises:
    ConfigurationError: if leftovers is not between 1 and 7 inclusive.
  """
  lane_temps = []
  for lane in lanes:
    lane_temps.append(registers.QuadRegister())

  GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

  quantize_setup = []
  for (lane_temp, lane) in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(lane_temp)])
    # The second load register only receives data for more than 4 leftovers.
    if leftovers > 4:
      quantize_setup.append(
          [lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter,
                   registers,
                   quantize_setup,
                   lane_temps,
                   multiplicative_offset,
                   rounding_offset,
                   shift)

  GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)

  # Hand the temporaries back to the allocator, as GenerateLoadQuantizeStore
  # does; previously they stayed reserved after this function returned.
  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)


def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
  """Emits optimized quantization code for given lanes and row size.

  The emitted function is named by BuildName(qnt_lanes, leftovers, aligned).
  Its body runs a loop that handles 8 values per lane and iteration, followed
  by a leftover step when leftovers is non-zero.

  Args:
    emitter: code emitter that receives the generated function.
    qnt_lanes: number of lanes processed together; 1, 2 or 3.
    leftovers: value of count % 8 the function is specialized for; 0 to 7.
    aligned: when true, asserts on the destination alignment are emitted and
      the main-loop stores are issued with an alignment of 64.

  Raises:
    ConfigurationError: if leftovers or qnt_lanes is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if qnt_lanes < 1 or qnt_lanes > 3:
    raise ConfigurationError('Qnt_lanes should be 1, 2 or 3.')

  name = BuildName(qnt_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(name,
                             [['const std::int32_t*', 'source'],
                              ['std::int32_t', 'count'],
                              ['std::int32_t', 'stride'],
                              ['const std::int32_t*', 'offsets'],
                              ['std::uint8_t*', 'destination'],
                              ['std::int32_t', 'destination_stride'],
                              ['std::int32_t', 'multiplicative_offset'],
                              ['std::int32_t', 'rounding_offset'],
                              ['std::int32_t', 'shift']],
                             'void')
  # Preconditions of the emitted function.
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    if qnt_lanes > 1:
      emitter.EmitAssert(
          'destination_stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  # The scalar quantization parameters are duplicated into quad registers.
  multiplicative_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('multiplicative_offset'))
  rounding_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('rounding_offset'))
  shift = DuplicateRegister(emitter, registers, registers.MapParameter('shift'))

  lanes = GenerateQntLanes(
      emitter, registers, qnt_lanes,
      registers.MapParameter('source'),
      registers.MapParameter('stride'),
      registers.MapParameter('destination'),
      registers.MapParameter('destination_stride'),
      registers.MapParameter('offsets'))

  if leftovers:
    # Take the leftovers off count so the main loop sees a multiple of 8;
    # the branch presumably skips the loop when nothing else remains.
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  # Main loop (label 1): 8 values per lane and iteration.
  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadQuantizeStore(emitter,
                            registers,
                            lanes,
                            multiplicative_offset,
                            rounding_offset,
                            shift,
                            64 if aligned else None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    # Leftover step (label 2).
    emitter.EmitNumericalLabel(2)
    GenerateLeftoverLoadQuantizeStore(emitter,
                                      registers,
                                      leftovers,
                                      lanes,
                                      multiplicative_offset,
                                      rounding_offset,
                                      shift)

  emitter.EmitAsmEnd(registers.MappedParameters(),
                     [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter):
  """Emits every qnt function variant, each followed by a newline.

  Covers aligned and unaligned versions of 1 to 3 lanes with 0 to 7 leftovers,
  aligned versions first.
  """
  for aligned in (True, False):
    for lanes in (1, 2, 3):
      for leftovers in range(8):
        GenerateQntNx8(emitter, lanes, leftovers, aligned)
        emitter.EmitNewline()