#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Slicing the input Model file

Invoked by ml/nn/runtime/test/specs/slicing.sh; this Python code is
not intended to be invoked directly by users. See that script for
details on how the slicing tool is used.

This script does the following work:

Perform a topological sort similar to the test generator, except that:
* It would stop at the N-th operation it encounters, and
* Rename the output of the N-th operation to a model output, and
* Name that as the output of the model.
* Also only inputs and weights used by the submodel would be emitted.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from functools import reduce
import math
import os
import struct
import sys
import contextlib
import test_generator
import pprint
# Stuff from test generator
from test_generator import Configuration
from test_generator import Example
from test_generator import Float32Scalar
from test_generator import Float32Vector
from test_generator import IgnoredOutput
from test_generator import Input
from test_generator import Int32Scalar
from test_generator import Int32Vector
from test_generator import Internal
from test_generator import Model
from test_generator import Output
from test_generator import Parameter
from test_generator import SmartOpen


# Take a model from command line
def import_source():
  """Parse the command line and execute the spec file it names.

  Command-line arguments:
    spec: path of the spec file to execute.
    -n/--number: number of operations in the sliced model (integer,
      default 1). Non-integer values are rejected by argparse with a usage
      error instead of failing later when the caller converts the value.
    -m/--model: output model file (default "-").
    -e/--example: output example file (default "-").

  Side effects:
    Sets test_generator.FileNames.specFile to the base name of the spec file
    and executes the spec file's contents with exec().

  Returns:
    A (model, example, number) tuple taken from the parsed arguments.

  Exits with status 1 (after printing a message) if the spec file does not
  exist.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("spec", help="the spec file")
  parser.add_argument(
      "-n", "--number",
      help="number of operations in the sliced model. Default = 1",
      type=int,
      default=1)
  parser.add_argument(
      "-m", "--model", help="the output model file", default="-")
  parser.add_argument(
      "-e", "--example", help="the output example file", default="-")
  args = parser.parse_args()

  if not os.path.exists(args.spec):
    print("cannot find file %s" % args.spec)
    sys.exit(1)

  test_generator.FileNames.specFile = os.path.basename(args.spec)
  # Read through a context manager so the file handle is closed promptly
  # rather than left to the garbage collector.
  with open(args.spec) as spec_file:
    spec_source = spec_file.read()
  # NOTE(security): the spec file is run as arbitrary Python code. Only use
  # this tool on spec files you trust.
  exec(spec_source)

  return (args.model, args.example, args.number)


# Slice till the Nth op the topological sort finds
# the output of that op becomes the output of the model
class slicing:
  """Accumulates the leading operations of a model as Python source lines.

  Operations are fed in one at a time through format_as_py_op(). Those that
  have a Python definition are counted; the ones counted up to and including
  `threshold` are recorded, together with every operand they read or write.
  The outputs of the most recently recorded operation are remembered so they
  can be emitted as IgnoredOutput operands and aliased in the example file.
  """

  def __init__(self, threshold):
    """Create a slicer that records at most `threshold` operations."""
    self.__threshold = threshold
    self.__nr_op_seen = 0
    self.__all_formatted_ops = []
    self.__referenced_operands = set()
    self.__last_outs = []

  def format_as_py_op(self, op):
    """Record `op` if it has a Python definition and the limit is not exceeded.

    Returns:
      None if op.PyDefinition() is None (the op is neither counted nor
      recorded), False if the op is counted but exceeds the threshold, and
      True if the op was recorded.
    """
    definition = op.PyDefinition()
    if definition is None:
      return None

    # Ops beyond the threshold are still counted, just not recorded.
    self.__nr_op_seen += 1
    if self.__nr_op_seen > self.__threshold:
      return False

    self.__last_outs = op.outs
    self.__referenced_operands.update(op.ins)
    self.__referenced_operands.update(op.outs)
    self.__all_formatted_ops.append("model = model.%s" % definition)
    return True

  def dump(self, model_file):
    """Print every recorded operation line to `model_file`, one per line."""
    for formatted_op in self.__all_formatted_ops:
      print(formatted_op, file=model_file)

  def dump_example(self, example_file):
    """Write output aliases and then the examples to `example_file`.

    For each output of the last recorded operation, an alias assignment is
    printed and the operand's element count is stored in the override
    mapping handed to Example.py_dump, along with the referenced operands.
    """
    alias_template = (
        "# Alias for the output variable {operand_name}\n"
        "aliased_output{number} = {operand_name}\n")
    override = {}
    for last_out in self.__last_outs:
      name = str(last_out)
      override[name] = last_out.type.GetNumberOfElements()
      # Only one output is supported as of now, hence the fixed index 0.
      print(alias_template.format(operand_name=name, number=0),
            file=example_file)
    Example.py_dump(example_file, override, self.__referenced_operands)

  def format_operands(self, model):
    """Return Python definitions for the operands the recorded ops reference.

    Operands of `model` that no recorded operation touches are skipped.
    Parameters are emitted with their value as initializer; outputs of the
    last recorded operation are emitted as IgnoredOutput; every other operand
    keeps its own class name. Definitions are joined with newlines.
    """
    prefix = '{op_name} = {operand}("{op_name}", "{element_type}", "{shape}" '
    final_outputs = set(self.__last_outs)
    definitions = []
    for operand in model.operands:
      if operand not in self.__referenced_operands:
        continue
      operand_type = operand.type
      if isinstance(operand, test_generator.Parameter):
        template = prefix + ", {initializer})"
        initializer = operand.value
        kind = "Parameter"
      else:
        template = prefix + ")"
        initializer = []
        if operand in final_outputs:
          kind = "IgnoredOutput"
        else:
          kind = operand.__class__.__name__
      definitions.append(template.format(
          element_type=operand_type.type,
          shape=operand_type.GetRawShape(),
          op_name=str(operand),
          operand=kind,
          initializer=initializer))
    return "\n".join(definitions)


# Script entry point: slice the spec named on the command line, then write
# the sliced model file followed by the example file.
if __name__ == "__main__":
  model_path, example_path, op_limit = import_source()
  slicer = slicing(int(op_limit))

  with SmartOpen(model_path) as model_file:
    origin = " (from: %s)" % (test_generator.FileNames.specFile)
    print("# Generated file%s. Do not edit" % (origin), file=model_file)
    print("model = Model()", file=model_file)
    # The slicing tool only supports a single model per spec file, so only
    # the first registered model is compiled.
    compiled_model = Model.models[0].Compile()
    # Every operation is offered to the slicer; it decides which to record.
    for operation in compiled_model.operations:
      slicer.format_as_py_op(operation)
    # Operand definitions go before the operations that use them.
    print(slicer.format_operands(compiled_model), file=model_file)
    slicer.dump(model_file)

  with SmartOpen(example_path) as example_file:
    slicer.dump_example(example_file)