# Copyright (c) 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A number of placeholders and their rules for expansion when used in tests.
These placeholders, when used in spirv_args or expected_* variables of
SpirvTest, have special meanings. In spirv_args, they will be substituted by
the result of instantiate_for_spirv_args(), while in expected_*, by
instantiate_for_expectation(). A TestCase instance is passed as the
argument to the instantiate_*() methods.
"""
import os
import subprocess
import tempfile
from string import Template
class PlaceHolderException(Exception):
  """Error raised by placeholder objects.

  For example, the PlaceHolder base class raises it when a subclass has not
  implemented one of the instantiate_*() methods.
  """
class PlaceHolder(object):
  """Base class for placeholders.

  Concrete placeholders override both instantiate_for_spirv_args() and
  instantiate_for_expectation(). The implementations here only raise
  PlaceHolderException.
  """

  def instantiate_for_spirv_args(self, testcase):
    """Expands this placeholder where it appears in spirv_args.

    Args:
      testcase: The TestCase instance being run.

    Returns:
      The string that replaces this placeholder in spirv_args.

    Raises:
      PlaceHolderException: Always; subclasses must override this method.
    """
    message = 'Subclass should implement this function.'
    raise PlaceHolderException(message)

  def instantiate_for_expectation(self, testcase):
    """Expands this placeholder where it appears in expected_*.

    Args:
      testcase: The TestCase instance being run.

    Returns:
      The string that replaces this placeholder in expected_*.

    Raises:
      PlaceHolderException: Always; subclasses must override this method.
    """
    message = 'Subclass should implement this function.'
    raise PlaceHolderException(message)
class FileShader(PlaceHolder):
  """Stands for a shader whose source code is in a file.

  When expanded in spirv_args, the source text is written to a new temporary
  file inside the test case's directory, and that file's name is used.
  """

  def __init__(self, source, suffix, assembly_substr=None):
    """Initializes the placeholder.

    Args:
      source: Shader source text to write into the temporary file.
      suffix: Suffix given to the temporary file's name.
      assembly_substr: Optional substring which is expected to be in the
        disassembly of the module generated from this input file.
    """
    assert isinstance(source, str)
    assert isinstance(suffix, str)
    self.source = source
    self.suffix = suffix
    # Set by instantiate_for_spirv_args(); None until then.
    self.filename = None
    self.assembly_substr = assembly_substr

  def instantiate_for_spirv_args(self, testcase):
    """Creates a temporary file and writes the source into it.

    Args:
      testcase: The TestCase instance; its `directory` attribute is where the
        temporary file is created.

    Returns:
      The name of the temporary file.
    """
    shader_fd, self.filename = tempfile.mkstemp(
        dir=testcase.directory, suffix=self.suffix)
    # os.fdopen takes ownership of the descriptor; the with-block closes it
    # even if the write raises, so the descriptor is never leaked.
    with os.fdopen(shader_fd, 'w') as shader_object:
      shader_object.write(self.source)
    return self.filename

  def instantiate_for_expectation(self, testcase):
    """Returns the temporary file name created by instantiate_for_spirv_args().

    The placeholder must already have been expanded in spirv_args.
    """
    assert self.filename is not None
    return self.filename
class ConfigFlagsFile(PlaceHolder):
  """Stands for a configuration file for spirv-opt generated out of a string."""

  def __init__(self, content, suffix):
    """Initializes the placeholder.

    Args:
      content: Text to write into the configuration file.
      suffix: Suffix given to the temporary file's name.
    """
    assert isinstance(content, str)
    assert isinstance(suffix, str)
    self.content = content
    self.suffix = suffix
    # Set by instantiate_for_spirv_args(); None until then.
    self.filename = None

  def instantiate_for_spirv_args(self, testcase):
    """Creates a temporary file and writes content into it.

    Args:
      testcase: The TestCase instance; its `directory` attribute is where the
        temporary file is created.

    Returns:
      The command-line argument '-Oconfig=<name of the temporary file>'.
    """
    temp_fd, self.filename = tempfile.mkstemp(
        dir=testcase.directory, suffix=self.suffix)
    # The with-block closes the file (and its descriptor) even if the write
    # raises.
    with os.fdopen(temp_fd, 'w') as config_file:
      config_file.write(self.content)
    return '-Oconfig=%s' % self.filename

  def instantiate_for_expectation(self, testcase):
    """Returns the bare temporary file name, without the '-Oconfig=' prefix.

    The placeholder must already have been expanded in spirv_args.
    """
    assert self.filename is not None
    return self.filename
class FileSPIRVShader(PlaceHolder):
  """Stands for a source shader file which must be converted to SPIR-V."""

  def __init__(self, source, suffix, assembly_substr=None):
    """Initializes the placeholder.

    Args:
      source: Source text to write into the temporary file before assembling.
      suffix: Suffix given to the temporary source file's name.
      assembly_substr: Optional substring which is expected to be in the
        disassembly of the module generated from this input file.
    """
    assert isinstance(source, str)
    assert isinstance(suffix, str)
    self.source = source
    self.suffix = suffix
    # Name of the assembled output; set by instantiate_for_spirv_args().
    self.filename = None
    self.assembly_substr = assembly_substr

  def instantiate_for_spirv_args(self, testcase):
    """Creates a temporary file, writes the source into it and assembles it.

    The assembler at testcase.test_manager.assembler_path is run in
    testcase.directory. Its output is written next to the source file, with
    '.spv' appended to the source file's name.

    Args:
      testcase: The TestCase instance being run.

    Returns:
      The name of the assembled temporary file.
    """
    shader_fd, asm_filename = tempfile.mkstemp(
        dir=testcase.directory, suffix=self.suffix)
    # The with-block closes the file (and its descriptor) even if the write
    # raises.
    with os.fdopen(shader_fd, 'w') as shader_object:
      shader_object.write(self.source)
    self.filename = '%s.spv' % asm_filename
    cmd = [
        testcase.test_manager.assembler_path, asm_filename, '-o', self.filename
    ]
    process = subprocess.Popen(
        args=cmd,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        cwd=testcase.directory)
    stdout, stderr = process.communicate()
    # The assembler must succeed silently; include its output in the failure
    # message so a broken test input is easy to diagnose.
    assert process.returncode == 0 and not stdout and not stderr, (
        'Assembling %s failed (exit code %s):\nstdout: %r\nstderr: %r' %
        (asm_filename, process.returncode, stdout, stderr))
    return self.filename

  def instantiate_for_expectation(self, testcase):
    """Returns the assembled file name from instantiate_for_spirv_args().

    The placeholder must already have been expanded in spirv_args.
    """
    assert self.filename is not None
    return self.filename
class StdinShader(PlaceHolder):
  """Stands for a shader whose source code is supplied through stdin."""

  def __init__(self, source):
    """Initializes the placeholder with the shader source text."""
    assert isinstance(source, str)
    self.source = source
    # Becomes '-' once the placeholder has been expanded in spirv_args.
    self.filename = None

  def instantiate_for_spirv_args(self, testcase):
    """Stores the source on the TestCase and returns '-' as the file name.

    Args:
      testcase: The TestCase instance; its `stdin_shader` attribute is set to
        the source text.

    Returns:
      The string '-'.
    """
    stdin_marker = '-'
    testcase.stdin_shader = self.source
    self.filename = stdin_marker
    return stdin_marker

  def instantiate_for_expectation(self, testcase):
    """Returns '-', provided the placeholder was already used in spirv_args."""
    filename = self.filename
    assert filename is not None
    return filename
class TempFileName(PlaceHolder):
  """Stands for a temporary file's name.

  Both expansions resolve the stored name relative to the test case's
  directory.
  """

  def __init__(self, filename):
    """Initializes the placeholder.

    Args:
      filename: Non-empty file name, relative to the test case's directory.
    """
    assert isinstance(filename, str)
    assert filename
    self.filename = filename

  def _resolve(self, testcase):
    # Joins the stored name onto testcase.directory.
    return os.path.join(testcase.directory, self.filename)

  def instantiate_for_spirv_args(self, testcase):
    """Returns the file's path inside testcase.directory."""
    return self._resolve(testcase)

  def instantiate_for_expectation(self, testcase):
    """Returns the file's path inside testcase.directory."""
    return self._resolve(testcase)
class SpecializedString(PlaceHolder):
  """Returns a string that has been specialized based on TestCase.

  The stored text is expanded as a string.Template. Each $param is replaced
  by the TestCase member of the same name, taken from vars(testcase).
  Template.substitute() raises KeyError if a referenced member is missing.
  """

  def __init__(self, filename):
    """Initializes the placeholder.

    Args:
      filename: Non-empty template text. Despite the parameter name, this is
        the string to specialize, not necessarily a file name.
    """
    assert isinstance(filename, str)
    assert filename
    self.filename = filename

  def _specialize(self, testcase):
    # Expands the stored template against the TestCase's instance attributes.
    members = vars(testcase)
    return Template(self.filename).substitute(members)

  def instantiate_for_spirv_args(self, testcase):
    """Returns the template text specialized for `testcase`."""
    return self._specialize(testcase)

  def instantiate_for_expectation(self, testcase):
    """Returns the template text specialized for `testcase`."""
    return self._specialize(testcase)