# -*- coding: utf-8 -*-

#-------------------------------------------------------------------------
# drawElements Quality Program utilities
# --------------------------------------
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#-------------------------------------------------------------------------

import sys, logging, re
from lxml import etree
from collections import OrderedDict
from functools import wraps, partial

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Short aliases for the logger's bound methods, used throughout this file.
debug = log.debug
info = log.info
warning = log.warning

def warnElem(elem, fmt, *args):
	"""Log a warning prefixed with the origin and identity of `elem`.

	The prefix has the form "<base>:<sourceline>, <tag> <name>: " (the name
	is empty when the element has no `name` attribute). `fmt` and `args` are
	appended as an ordinary logging format string and its arguments.
	"""
	elemName = elem.get('name') or ''
	prefix = '%s:%d, %s %s: '
	warning(prefix + fmt, elem.base, elem.sourceline, elem.tag, elemName, *args)

class Object(object):
	"""Plain attribute bag: every keyword argument becomes an instance attribute."""
	def __init__(self, **kwargs):
		vars(self).update(kwargs)

class Located(Object):
	"""Object carrying a `location` attribute that defaults to None."""
	# makeObject stores an (elem.base, elem.sourceline) pair here.
	location = None

class Group(Located):
	"""Located record describing an enum group."""

class Enum(Located):
	"""Located record describing a single enum."""
class Enums(Located):
	"""Located record whose `name`, `comment` and `enums` default to None."""
	enums = None
	comment = None
	name = None

class Type(Located):
	"""Located record describing a type; every field defaults to None."""
	location = None
	name = None
	definition = None
	api = None
	requires = None

def makeObject(cls, elem, **kwargs):
	"""Instantiate `cls` from `elem`, filling in name, comment and location.

	`name` and `comment` are read from the element's attributes unless the
	caller supplied them as keyword arguments. `location` is always
	overwritten with the element's (base, sourceline) pair.
	"""
	for attrib in ('name', 'comment'):
		if attrib not in kwargs:
			kwargs[attrib] = elem.get(attrib)
	kwargs['location'] = (elem.base, elem.sourceline)
	return cls(**kwargs)

def parseEnum(eEnum):
	"""Build an Enum from an element's `value`, `type` and `alias` attributes."""
	attribs = dict((key, eEnum.get(key)) for key in ('value', 'type', 'alias'))
	return makeObject(Enum, eEnum, **attribs)

class Param(Located):
	"""Located record describing one command parameter."""

class Command(Located):
	"""Located record describing a command; every field defaults to None."""
	name = None
	declaration = None
	type = None
	ptype = None
	group = None
	params = None
	alias = None

# NOTE(review): Interface is declared a second time further down this file;
# that later declaration is the one the name refers to after import.
class Interface(Object):
	"""Attribute bag used as the result type of createInterface."""

class Index:
	"""Ordered collection of items that can also be looked up by key.

	`items` keeps every appended item in insertion order, while `index` maps
	keys to items. Subclasses customise behaviour through four hooks:

	* getkeys(item)       -- keys under which an appended item is stored
	* nextkey(key)        -- fallback key to try after a failed lookup
	* duplicateKey(k, it) -- called when a key is already present
	* missingKey(key)     -- called when a lookup fails for good
	"""
	def __init__(self, items=(), **kwargs):
		"""Create the index, set `kwargs` as attributes, then append `items`.

		The default for `items` is an immutable empty tuple so that no
		mutable object is shared between calls.
		"""
		self.index = {}
		self.items = []
		self.__dict__.update(kwargs)
		self.update(items)

	def append(self, item):
		"""Store `item` under each of its keys and add it to the item list."""
		keys = self.getkeys(item)
		for key in keys:
			self[key] = item
		self.items.append(item)

	def update(self, items):
		"""Append every item of the iterable `items`."""
		for item in items:
			self.append(item)

	def __iter__(self):
		"""Iterate over the items in insertion order."""
		return iter(self.items)

	def nextkey(self, key):
		"""Return the next key to try after `key` missed; raise KeyError if none."""
		raise KeyError

	def getkeys(self, item):
		"""Return the list of keys for `item`; the base class indexes nothing."""
		return []

	def __contains__(self, key):
		"""Report whether `key` is stored directly (fallback keys are not tried)."""
		return key in self.index

	def __setitem__(self, key, item):
		"""Store `item` under `key`; an existing entry is kept and reported."""
		if key in self.index:
			self.duplicateKey(key, item)
		else:
			self.index[key] = item

	def duplicateKey(self, key, item):
		"""Log a warning about an attempt to reuse `key`."""
		warning("Duplicate %s: %r", type(item).__name__.lower(), key)

	def __getitem__(self, key):
		"""Look up `key`, walking the nextkey() fallback chain.

		When the chain is exhausted, missingKey() is asked for a replacement
		with the last key tried; if it returns an item, that item is appended
		to the index and returned. If missingKey() raises, the exception
		propagates to the caller.
		"""
		try:
			while True:
				try:
					return self.index[key]
				except KeyError:
					pass
				# A KeyError from nextkey() ends the chain via the outer handler.
				key = self.nextkey(key)
		except KeyError:
			item = self.missingKey(key)
			self.append(item)
			return item

	def missingKey(self, key):
		"""Produce an item for an unknown `key`; the base class raises KeyError."""
		raise KeyError(key)

	def __len__(self):
		"""Return the number of appended items."""
		return len(self.items)

class ElemNameIndex(Index):
	"""Index of XML elements keyed by their `name` attribute."""
	def getkeys(self, item):
		elemName = item.get('name')
		return [elemName]

	def duplicateKey(self, key, item):
		# Report the clash together with the element's source position.
		warnElem(item, "Duplicate key: %s", key)

class CommandIndex(Index):
	"""Index of command elements keyed by the name inside their <proto>."""
	def getkeys(self, item):
		# The <alias> child is deliberately not used as a key. Its target is
		# stored in the `name` attribute (see extractAlias), so
		# findtext('alias') yields only None or the element's empty text.
		# That value is shared by many commands and would register them all
		# under one meaningless key, causing duplicate-key warnings.
		return [item.findtext('proto/name')]

class NameApiIndex(Index):
	"""Index keyed by (name, api) pairs with a fallback to (name, None)."""
	def getkeys(self, item):
		key = (item.get('name'), item.get('api'))
		return [key]

	def nextkey(self, key):
		# Only an api-specific pair has a fallback: the api-neutral entry.
		if len(key) != 2 or key[1] is None:
			raise KeyError
		return (key[0], None)

	def duplicateKey(self, key, item):
		# Report the clash together with the element's source position.
		warnElem(item, "Duplicate key: %s", key)

class TypeIndex(NameApiIndex):
	"""NameApiIndex for type elements, whose name may be a child element."""
	def getkeys(self, item):
		# Prefer the `name` attribute; fall back to the text of a <name> child.
		typeName = item.get('name')
		if not typeName:
			typeName = item.findtext('name')
		return [(typeName, item.get('api'))]

class EnumIndex(NameApiIndex):
	"""NameApiIndex for enum elements, also reachable through their alias."""
	def getkeys(self, item):
		"""Return (name, api) plus (alias, api) when the enum has an alias."""
		name, api, alias = (item.get(attrib) for attrib in ['name', 'api', 'alias'])
		return [(name, api)] + ([(alias, api)] if alias is not None else [])

	def duplicateKey(self, key, item):
		"""Warn about a clashing key, distinguishing alias clashes.

		`key` is a (name, api) pair. It is unpacked in the body because
		tuple parameters in the signature are a syntax error on Python 3.
		"""
		name, _api = key
		if name == item.get('alias'):
			warnElem(item, "Alias already present: %s", name)
		else:
			warnElem(item, "Already present")

class Registry:
	"""Indexed view over the elements of a parsed registry document.

	Attributes built from `eRegistry`:
	  types, groups, enums, commands, features, extensions -- Index objects
	      over the corresponding XML elements.
	  apis    -- dict mapping an api name to its list of feature elements,
	             sorted by their `number` attribute.
	  element -- the `eRegistry` object itself.
	"""
	def __init__(self, eRegistry):
		self.types = TypeIndex(eRegistry.findall('types/type'))
		self.groups = ElemNameIndex(eRegistry.findall('groups/group'))
		self.enums = EnumIndex(eRegistry.findall('enums/enum'))
		# Register each enum that names a group under that group name. An
		# existing entry is kept and reported as a duplicate by the index.
		for eEnum in self.enums:
			groupName = eEnum.get('group')
			if groupName is not None:
				self.groups[groupName] = eEnum
		self.commands = CommandIndex(eRegistry.findall('commands/command'))
		self.features = ElemNameIndex(eRegistry.findall('feature'))
		self.apis = {}
		for eFeature in self.features:
			self.apis.setdefault(eFeature.get('api'), []).append(eFeature)
		# values() exists on both Python 2 and 3 (itervalues() is 2-only);
		# the lists are sorted in place either way.
		# NOTE(review): `number` is compared as a string here -- confirm that
		# lexical order matches version order for every registry in use.
		for apiFeatures in self.apis.values():
			apiFeatures.sort(key=lambda eFeature: eFeature.get('number'))
		self.extensions = ElemNameIndex(eRegistry.findall('extensions/extension'))
		self.element = eRegistry

	def getFeatures(self, api, checkVersion=None):
		"""Return the feature elements of `api`, optionally filtered.

		`checkVersion`, when given, is called with each feature's `number`
		attribute and must return true for the feature to be kept.
		Raises KeyError when `api` has no features.
		"""
		return [eFeature for eFeature in self.apis[api]
				if checkVersion is None or checkVersion(eFeature.get('number'))]

class NameIndex(Index):
	"""Index keyed by each item's `name` attribute.

	When `createMissing` is set, a failed lookup logs a warning and creates
	the item on the fly by calling `createMissing(name=key)`; `kind` is only
	used in that warning message.
	"""
	createMissing = None
	kind = "item"

	def getkeys(self, item):
		return [item.name]

	def missingKey(self, key):
		factory = self.createMissing
		if not factory:
			raise KeyError
		warning("Reference to implicit %s: %r", self.kind, key)
		return factory(name=key)

def matchApi(api1, api2):
	"""Return whether two api (or profile) names are compatible.

	None acts as a wildcard on either side; otherwise the names must be equal.
	"""
	if api1 is None or api2 is None:
		return True
	return api1 == api2

class Interface(Object):
	"""Attribute bag returned by createInterface (types, enums, groups, commands)."""

def extractAlias(eCommand):
	"""Return the `name` of the command's first <alias> child, or None."""
	names = eCommand.xpath('alias/@name')
	if not names:
		return None
	return names[0]

def getExtensionName(eExtension):
	"""Return the value of the extension element's `name` attribute."""
	extName = eExtension.get('name')
	return extName

def extensionSupports(eExtension, api, profile=None):
	"""Return whether the extension lists `api` in its `supported` attribute.

	The attribute is a '|'-separated list of api names. The combination
	api 'gl' with profile 'core' is looked up as 'glcore'. An extension
	without a `supported` attribute supports nothing.
	"""
	if api == 'gl' and profile == 'core':
		needSupport = 'glcore'
	else:
		needSupport = api
	supported = eExtension.get('supported')
	if supported is None:
		# No attribute at all: treat as unsupported instead of failing on
		# None.split().
		return False
	return needSupport in supported.split('|')

class InterfaceSpec(Object):
	"""Sets of type, enum and command names that make up an interface.

	The sets are filled by replaying the <require> and <remove> children of
	feature and extension elements.
	"""
	def __init__(self):
		self.enums = set()
		self.types = set()
		self.commands = set()

	def addComponent(self, eComponent):
		"""Apply one <require> or <remove> element to the name sets."""
		if eComponent.tag == 'require':
			def modify(items, item): items.add(item)
		else:
			assert eComponent.tag == 'remove'
			def modify(items, item):
				try:
					items.remove(item)
				except KeyError:
					# Removing something never required is reported, not fatal.
					warning("Tried to remove absent item: %s", item)
		for typeName in eComponent.xpath('type/@name'):
			modify(self.types, typeName)
		for enumName in eComponent.xpath('enum/@name'):
			modify(self.enums, enumName)
		for commandName in eComponent.xpath('command/@name'):
			modify(self.commands, commandName)

	def addComponents(self, elem, api, profile=None):
		"""Apply every <require>/<remove> child of `elem` matching api and profile.

		A child matches when its `api` and `profile` attributes are absent or
		equal to the requested values (see matchApi).
		"""
		for eComponent in elem.xpath('require|remove'):
			cApi = eComponent.get('api')
			cProfile = eComponent.get('profile')
			if matchApi(api, cApi) and matchApi(profile, cProfile):
				self.addComponent(eComponent)

	def addFeature(self, eFeature, api=None, profile=None, force=False):
		"""Add a feature element's components.

		A feature for a different api is skipped unless `force` is set, in
		which case a warning is logged and the feature is added anyway.
		"""
		info('Feature %s', eFeature.get('name'))
		if not matchApi(api, eFeature.get('api')):
			if not force: return
			warnElem(eFeature, 'API %s is not supported', api)
		self.addComponents(eFeature, api, profile)

	def addExtension(self, eExtension, api=None, profile=None, force=False):
		"""Add an extension element's components.

		An extension that does not support the api/profile is skipped unless
		`force` is set, in which case a warning is logged and it is added
		anyway.
		"""
		if not extensionSupports(eExtension, api, profile):
			if not force: return
			# Pass the values as log arguments rather than pre-formatting, so
			# a '%' in the extension name cannot corrupt the format string.
			warnElem(eExtension, '%s is not supported in API %s',
					 getExtensionName(eExtension), api)
		self.addComponents(eExtension, api, profile)

def createInterface(registry, spec, api=None):
	"""Build an Interface of parsed objects for the names listed in `spec`.

	`registry` supplies the XML elements, `spec` the sets of type, enum and
	command names (`spec.types`, `spec.enums`, `spec.commands`), and `api`
	selects api-specific type and enum entries.

	Returns an Interface whose `types`, `enums`, `groups` and `commands`
	attributes are NameIndex objects sorted by source location. Types and
	groups referenced by commands or parameters but not listed in `spec`
	are created on demand.
	"""
	def parseType(eType):
		# todo: apientry
		#requires = eType.get('requires')
		#if requires is not None:
		#    types[requires]
		return makeObject(
			Type, eType,
			name=eType.get('name') or eType.findtext('name'),
			definition=''.join(eType.xpath('.//text()')),
			api=eType.get('api'),
			requires=eType.get('requires'))

	def createType(name):
		info('Add type %s', name)
		try:
			return parseType(registry.types[name, api])
		except KeyError:
			# Not in the registry: fall back to a bare, unlocated Type.
			return Type(name=name)

	def createEnum(enumName):
		info('Add enum %s', enumName)
		return parseEnum(registry.enums[enumName, api])

	def extractPtype(elem):
		ePtype = elem.find('ptype')
		if ePtype is None:
			return None
		# Lookup may create the type implicitly through NameIndex.missingKey.
		return types[ePtype.text]

	def extractGroup(elem):
		groupName = elem.get('group')
		if groupName is None:
			return None
		# Lookup may create the group implicitly through NameIndex.missingKey.
		return groups[groupName]

	def parseParam(eParam):
		return makeObject(
			Param, eParam,
			name=eParam.get('name') or eParam.findtext('name'),
			declaration=''.join(eParam.xpath('.//text()')).strip(),
			type=''.join(eParam.xpath('(.|ptype)/text()')).strip(),
			ptype=extractPtype(eParam),
			group=extractGroup(eParam))

	def createCommand(commandName):
		info('Add command %s', commandName)
		eCmd = registry.commands[commandName]
		eProto = eCmd.find('proto')
		return makeObject(
			Command, eCmd,
			name=eCmd.findtext('proto/name'),
			declaration=''.join(eProto.xpath('.//text()')).strip(),
			type=''.join(eProto.xpath('(.|ptype)/text()')).strip(),
			ptype=extractPtype(eProto),
			group=extractGroup(eProto),
			alias=extractAlias(eCmd),
			params=NameIndex(map(parseParam, eCmd.findall('param'))))

	def createGroup(name):
		info('Add group %s', name)
		try:
			eGroup = registry.groups[name]
		except KeyError:
			# Not in the registry: fall back to a bare, unlocated Group.
			return Group(name=name)
		return makeObject(
			Group, eGroup,
			# Missing enums are often from exotic extensions. Don't create dummy entries,
			# just filter them out.
			enums=NameIndex(enums[enumName] for enumName in eGroup.xpath('enum/@name')
							if enumName in enums))

	def sortedIndex(items):
		# Implicitly created items have location None. Rank them first
		# explicitly: Python 2 already ordered None before tuples, while
		# Python 3 raises TypeError when None is compared with a tuple.
		def locationKey(item):
			return (item.location is not None, item.location)
		return NameIndex(sorted(items, key=locationKey))

	groups = NameIndex(createMissing=createGroup, kind="group")
	types = NameIndex(map(createType, spec.types),
					  createMissing=createType, kind="type")
	enums = NameIndex(map(createEnum, spec.enums),
					  createMissing=Enum, kind="enum")
	commands = NameIndex(map(createCommand, spec.commands),
						 createMissing=Command, kind="command")

	# This is a mess because the registry contains alias chains whose
	# midpoints might not be included in the interface even though
	# endpoints are.
	for command in commands:
		alias = command.alias
		aliasCommand = None
		# Follow the chain through the registry to its final element.
		while alias is not None:
			aliasCommand = registry.commands[alias]
			alias = extractAlias(aliasCommand)
		# Point at the chain's endpoint only if the interface includes it.
		command.alias = None
		if aliasCommand is not None:
			name = aliasCommand.findtext('proto/name')
			if name in commands:
				command.alias = commands[name]

	return Interface(
		types=sortedIndex(types),
		enums=sortedIndex(enums),
		groups=sortedIndex(groups),
		commands=sortedIndex(commands))


def spec(registry, api, version=None, profile=None, extensionNames=(), protects=(), force=False):
	"""Collect the InterfaceSpec for an api version plus a list of extensions.

	registry       -- Registry to read features and extensions from.
	api            -- api name; registry.getFeatures raises KeyError if the
	                  registry has no features for it.
	version        -- None/False: include no features; True: include all;
	                  otherwise include features whose `number` attribute
	                  is <= version.
	                  NOTE(review): `number` is an attribute string, so this
	                  comparison is only meaningful for a comparable
	                  `version` value -- confirm what callers pass.
	profile        -- optional profile name used to filter components.
	extensionNames -- extensions to add, in order.
	protects       -- names initially considered available for the
	                  extensions' `protect` dependencies.
	force          -- add features/extensions even when they are unsupported
	                  or have an unavailable dependency (a warning is logged).

	Returns the populated InterfaceSpec.
	"""
	available = set(protects)
	ifaceSpec = InterfaceSpec()

	if version is None or version is False:
		def check(v): return False
	elif version is True:
		def check(v): return True
	else:
		def check(v): return v <= version

	for eFeature in registry.getFeatures(api, check):
		ifaceSpec.addFeature(eFeature, api, profile, force)

	for extName in extensionNames:
		eExtension = registry.extensions[extName]
		protect = eExtension.get('protect')
		if protect is not None and protect not in available:
			warnElem(eExtension, "Unavailable dependency %s", protect)
			if not force:
				continue
		ifaceSpec.addExtension(eExtension, api, profile, force)
		# Extensions processed so far satisfy later `protect` dependencies.
		available.add(extName)

	return ifaceSpec

def interface(registry, api, **kwargs):
	"""Build the Interface for `api`; `kwargs` are forwarded to spec()."""
	return createInterface(registry, spec(registry, api, **kwargs), api)

def parse(path):
	"""Parse the XML file at `path` with lxml and wrap it in a Registry."""
	tree = etree.parse(path)
	return Registry(tree)