"""Generates the specialized gemm functions."""

import mul_Nx8_Mx8_neon
import qnt_Nx8_neon
import zip_Nx8_neon

_QUANTIZED_8BIT = 'quantized_8bit'
_FULL_32BIT = 'full_32bit'
_FULL_FLOAT = 'full_float'


class Error(Exception):
    """Module level error."""


class ConfigurationError(Error):
    """Runtime configuration error."""


def GenerateCommonTempsCountersAndConsts(emitter, rows):
    """Emits the declarations shared by every generated gemm function.

    Args:
      emitter: code emitter providing the Emit* methods used below.
      rows: leftover row count (the generated function asserts m % 3 == rows).
        When non-zero, an extra offsets pointer for the leftover strip is
        declared.
    """
    emitter.EmitDeclare('const std::int32_t', 'row_chunks', 'm / 3')
    emitter.EmitDeclare('const std::int32_t', 'col_chunks', 'n / 3')
    emitter.EmitDeclare('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8')
    emitter.EmitDeclare('const std::int32_t', 'chunk_size', 'k * 3')
    emitter.EmitDeclare('const std::int32_t', 'zipped_chunk_size',
                        '(padded_k + 16) * 3')
    emitter.EmitDeclare('const std::int32_t', 'zipped_rhs_size',
                        '(padded_k + 16) * n')
    emitter.EmitDeclare('const std::uint8_t*', 'lhs_chunk', 'lhs')
    emitter.EmitDeclare('const std::uint8_t*', 'rhs_chunk', 'rhs')
    emitter.EmitDeclare('std::uint8_t*', 'zipped_lhs', 'scratch')
    emitter.EmitDeclare(
        'std::int32_t*', 'zipped_lhs_3_offsets',
        'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * 3)')
    # Value comparison: `is not 0` only worked through CPython's int caching.
    if rows != 0:
        emitter.EmitDeclare(
            'std::int32_t*', 'zipped_lhs_%d_offsets' % rows,
            'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k * %d)'
            % rows)
    emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs',
                        'scratch + zipped_chunk_size')
    emitter.EmitDeclare('std::uint8_t*', 'zipped_rhs_chunk', 'zipped_rhs')
    emitter.EmitDeclare('const std::int32_t', 'result_chunk_stride',
                        'result_stride * 3')
    emitter.EmitNewline()


def GenerateQuantized8BitTempsCountersAndConsts(emitter, rows):
    """Generates all the boilerplate variables for the q8 gemm function."""
    GenerateCommonTempsCountersAndConsts(emitter, rows)
    emitter.EmitDeclare('const std::int32_t', 'const_offset',
                        'lhs_offset * rhs_offset * k + result_offset')
    emitter.EmitDeclare('const std::int32_t', 'rounding_offset',
                        '(1 << (shift - 1))')
    emitter.EmitDeclare('std::int32_t*', 'temp_result',
                        'reinterpret_cast<std::int32_t*>('
                        'scratch + zipped_chunk_size + zipped_rhs_size)')
    emitter.EmitDeclare('std::uint8_t*', 'result_chunk', 'result')
    emitter.EmitDeclare('std::int32_t*', 'mul_result_chunk', 'temp_result')
    emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
                        '((n * 4 + 7) / 8) * 8')
    emitter.EmitNewline()


def GenerateFullTempsCountersAndConsts(emitter, result_type, rows):
    """Generates all the boilerplate variables for the int32 and float gemms."""
    GenerateCommonTempsCountersAndConsts(emitter, rows)
    emitter.EmitDeclare('const std::int32_t', 'const_offset',
                        'lhs_offset * rhs_offset * k')
    emitter.EmitDeclare(result_type, 'result_chunk', 'result')
    emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
    emitter.EmitDeclare('const std::int32_t', 'mul_result_chunk_stride_bytes',
                        'result_stride * 4')
    emitter.EmitNewline()


def ZipName(rows, leftovers, aligned):
    """Returns the zip function name built by zip_Nx8_neon.BuildName."""
    return zip_Nx8_neon.BuildName(rows, leftovers, aligned)


def GenerateZipRhs(emitter, aligned, cols, leftovers):
    """Emits the code responsible for zipping the rhs matrix."""
    emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
    emitter.EmitCall(
        ZipName(3, leftovers, aligned),
        ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
    emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')
    emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
    emitter.EmitCloseBracket()

    # Leftover columns (n % 3) get one extra zip call after the loop.
    if cols != 0:
        emitter.EmitCall(
            ZipName(cols, leftovers, aligned),
            ['rhs_chunk', 'k', 'k', 'zipped_rhs_chunk', 'lhs_offset', 0])
    emitter.EmitNewline()


def MulName(result_type, lhs_add, rhs_add, rows, cols):
    """Returns the mul function name built by mul_Nx8_Mx8_neon.BuildName."""
    return mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, rows,
                                      cols)


def GetMulParams(result_type):
    """Returns the argument list for a mul call; float adds 'result_scale'."""
    params = [
        'zipped_lhs',
        'zipped_rhs_chunk',
        'padded_k',
        'mul_result_chunk',
        'mul_result_chunk_stride_bytes',
    ]
    # Compare by value: identity on a str literal depends on interning.
    if result_type == 'float':
        params.append('result_scale')
    return params


def GenerateMulRows(emitter, result, result_type, lhs_add, rhs_add, aligned,
                    rows, cols, leftovers):
    """Emits code responsible for multiplication of one horizontal lhs strip."""
    emitter.EmitCall(
        ZipName(rows, leftovers, aligned),
        ['lhs_chunk', 'k', 'k', 'zipped_lhs', 'rhs_offset', 'const_offset'])
    emitter.EmitAssign('zipped_rhs_chunk', 'zipped_rhs')
    emitter.EmitAssign('mul_result_chunk', result)

    emitter.EmitOpenBracket('for (int j = 0; j < col_chunks; ++j)')
    emitter.EmitCall(
        MulName(result_type, lhs_add, rhs_add, rows, 3),
        GetMulParams(result_type))
    emitter.EmitAssignIncrement('zipped_rhs_chunk', 'zipped_chunk_size')
    emitter.EmitAssignIncrement('mul_result_chunk', 3)
    emitter.EmitCloseBracket()

    # Leftover columns (n % 3) are handled by one extra mul call.
    if cols != 0:
        emitter.EmitCall(
            MulName(result_type, lhs_add, rhs_add, rows, cols),
            GetMulParams(result_type))


def GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers):
    """Emits code for all lhs strips & leftover rows. Quantize after mul code."""
    emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
    GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned, 3,
                    cols, leftovers)
    emitter.EmitCall(
        qnt_Nx8_neon.BuildMultiQuantizeName(aligned, 3),
        ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
         'zipped_lhs_3_offsets', 'result_chunk', 'result_stride',
         'multiplicative_offset', 'rounding_offset', '-shift'])
    emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
    emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
    emitter.EmitCloseBracket()
    emitter.EmitNewline()

    # Leftover rows (m % 3): one more multiply + quantize for the last strip.
    if rows != 0:
        GenerateMulRows(emitter, 'temp_result', 'int32', False, True, aligned,
                        rows, cols, leftovers)
        emitter.EmitCall(
            qnt_Nx8_neon.BuildMultiQuantizeName(aligned, rows),
            ['temp_result', 'n', 'mul_result_chunk_stride_bytes',
             'zipped_lhs_%d_offsets' % rows, 'result_chunk', 'result_stride',
             'multiplicative_offset', 'rounding_offset', '-shift'])


def GenerateFullMul(emitter, result_type, aligned, rows, cols, leftovers):
    """Emits code for all lhs strips & leftover rows of int32/float gemms."""
    emitter.EmitOpenBracket('for (int i = 0; i < row_chunks; ++i)')
    GenerateMulRows(emitter, 'result_chunk', result_type, True, True, aligned,
                    3, cols, leftovers)
    emitter.EmitAssignIncrement('lhs_chunk', 'chunk_size')
    emitter.EmitAssignIncrement('result_chunk', 'result_chunk_stride')
    emitter.EmitCloseBracket()
    emitter.EmitNewline()

    # Leftover rows (m % 3) are multiplied directly into result_chunk.
    if rows != 0:
        GenerateMulRows(emitter, 'result_chunk', result_type, True, True,
                        aligned, rows, cols, leftovers)


def BuildName(output_type, aligned, rows, cols, leftover):
    """Returns the specialized gemm name: <main>_<rows>_<cols>_<leftover>.

    The suffix '_aligned' is appended when `aligned` is truthy.

    Raises:
      ConfigurationError: if output_type is not a supported output type.
    """
    name = BuildMainGemmName(output_type) + '_%d_%d_%d' % (rows, cols,
                                                           leftover)
    if aligned:
        name += '_aligned'
    return name


def GetCommonGemmParameters():
    """Returns the [type, name] pairs shared by every gemm signature."""
    return [
        ['std::uint8_t*', 'scratch'],
        ['const std::uint8_t*', 'lhs'],
        ['const std::uint8_t*', 'rhs'],
        ['std::int32_t', 'm'],
        ['std::int32_t', 'n'],
        ['std::int32_t', 'k'],
        ['std::int32_t', 'lhs_offset'],
        ['std::int32_t', 'rhs_offset'],
    ]


def GetGemmParameters(output_type, extra_params=None):
    """Prepares a (type, parameter) array for the gemm functions.

    Args:
      output_type: one of the module's output type constants.
      extra_params: optional list of [type, name] pairs appended at the end.

    Returns:
      A new list of [type, name] pairs.

    Raises:
      ConfigurationError: if output_type is not a supported output type.
    """
    if extra_params is None:
        extra_params = []
    params = GetCommonGemmParameters()
    # Output types are plain strings, so compare by value rather than identity.
    if output_type == _QUANTIZED_8BIT:
        params += [['std::int32_t', 'result_offset'],
                   ['std::int32_t', 'multiplicative_offset'],
                   ['std::int32_t', 'shift'],
                   ['std::uint8_t*', 'result']]
    elif output_type == _FULL_32BIT:
        params += [['std::int32_t*', 'result']]
    elif output_type == _FULL_FLOAT:
        params += [['float', 'result_scale'], ['float*', 'result']]
    else:
        raise ConfigurationError('Unsupported output type: %s' % output_type)
    return params + extra_params


def GetStridedGemmParameters(output_type):
    """Returns the gemm parameters followed by a 'result_stride' parameter."""
    return GetGemmParameters(output_type, [['std::int32_t', 'result_stride']])


def GenerateGemm(emitter, output_type, aligned, rows, cols, leftovers):
    """Build one gemm function for given row, col, and depth leftovers."""
    emitter.EmitFunctionBeginA(
        BuildName(output_type, aligned, rows, cols, leftovers),
        GetStridedGemmParameters(output_type), 'void')

    emitter.EmitAssert('m %% 3 == %d' % rows)
    emitter.EmitAssert('n %% 3 == %d' % cols)
    emitter.EmitAssert('k %% 8 == %d' % leftovers)

    if output_type == _QUANTIZED_8BIT:
        GenerateQuantized8BitTempsCountersAndConsts(emitter, rows)
        GenerateZipRhs(emitter, aligned, cols, leftovers)
        GenerateQuantized8BitMul(emitter, aligned, rows, cols, leftovers)
    elif output_type == _FULL_32BIT:
        GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*', rows)
        GenerateZipRhs(emitter, aligned, cols, leftovers)
        GenerateFullMul(emitter, 'int32', aligned, rows, cols, leftovers)
    elif output_type == _FULL_FLOAT:
        GenerateFullTempsCountersAndConsts(emitter, 'float*', rows)
        GenerateZipRhs(emitter, aligned, cols, leftovers)
        GenerateFullMul(emitter, 'float', aligned, rows, cols, leftovers)
    else:
        raise ConfigurationError('Unknown output type: %s' % output_type)

    emitter.EmitFunctionEnd()


def GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod, leftovers):
    """Emits a call to the specialized gemm in the 'internal' scope."""
    emitter.EmitCall(
        emitter.Scope('internal',
                      BuildName(output_type, aligned, m_mod, n_mod,
                                leftovers)),
        [p for (unused_t, p) in GetStridedGemmParameters(output_type)])


def GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod):
    """Third level of main switch, choose optimized version on depth leftover."""
    emitter.EmitSwitch('k % 8')

    for leftovers in range(0, 8):
        emitter.EmitCase(leftovers)
        emitter.PushIndent()
        GenerateGemmCall(emitter, output_type, aligned, m_mod, n_mod,
                         leftovers)
        emitter.EmitBreak()
        emitter.PopIndent()

    emitter.EmitSwitchEnd()


def GenerateGemmSwitch2(emitter, output_type, aligned, m_mod):
    """Second level of main switch, choose optimized version on cols leftover."""
    emitter.EmitSwitch('n % 3')

    for n_mod in range(0, 3):
        emitter.EmitCase(n_mod)
        emitter.PushIndent()
        GenerateGemmSwitch3(emitter, output_type, aligned, m_mod, n_mod)
        emitter.EmitBreak()
        emitter.PopIndent()

    emitter.EmitSwitchEnd()


def GenerateGemmSwitch1(emitter, output_type, aligned):
    """First level of main switch, choose optimized version on rows leftover."""
    emitter.EmitSwitch('m % 3')

    for m_mod in range(0, 3):
        emitter.EmitCase(m_mod)
        emitter.PushIndent()
        GenerateGemmSwitch2(emitter, output_type, aligned, m_mod)
        emitter.EmitBreak()
        emitter.PopIndent()

    emitter.EmitSwitchEnd()


def BuildMainGemmName(output_type):
    """Returns the base gemm function name for the given output type.

    Raises:
      ConfigurationError: if output_type is not a supported output type.
    """
    # Output types are plain strings, so compare by value rather than identity.
    if output_type == _QUANTIZED_8BIT:
        return 'gemm_q8'
    elif output_type == _FULL_32BIT:
        return 'gemm_i32'
    elif output_type == _FULL_FLOAT:
        return 'gemm_f'
    else:
        raise ConfigurationError('Unsupported output type: %s' % output_type)


def BuildStridedMainGemmName(output_type):
    """Returns the base gemm name with the '_strided' suffix."""
    return BuildMainGemmName(output_type) + '_strided'


def GenerateMainGemmFunction(emitter, output_type):
    """Emit high level gemm function that switches between optimized versions."""
    emitter.EmitFunctionBeginA(
        BuildStridedMainGemmName(output_type),
        GetStridedGemmParameters(output_type), 'void')

    emitter.EmitDeclare('const bool', 'lhs_aligned',
                        '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'rhs_aligned',
                        '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')

    # Only the quantized variant also checks the result pointer and stride.
    if output_type == _QUANTIZED_8BIT:
        emitter.EmitDeclare(
            'const bool', 'result_aligned',
            '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
        emitter.EmitDeclare('const bool', 'result_stride_aligned',
                            '((result_stride % 8) == 0)')
        emitter.EmitDeclare('const bool', 'aligned',
                            'lhs_aligned && rhs_aligned && result_aligned '
                            '&& k_aligned && result_stride_aligned')
    else:
        emitter.EmitDeclare('const bool', 'aligned',
                            'lhs_aligned && rhs_aligned && k_aligned')

    emitter.EmitIf('aligned')
    GenerateGemmSwitch1(emitter, output_type, True)
    emitter.EmitElse()
    GenerateGemmSwitch1(emitter, output_type, False)
    emitter.EmitEndif()
    emitter.EmitFunctionEnd()


def GenerateWrapperGemmFunction(emitter, output_type):
    """Emits the non-strided gemm, forwarding to the strided one.

    The emitted call passes 'n' as the trailing (result_stride) argument.
    """
    emitter.EmitFunctionBeginA(
        BuildMainGemmName(output_type), GetGemmParameters(output_type), 'void')
    emitter.EmitCall(
        BuildStridedMainGemmName(output_type),
        [p for (unused_t, p) in GetGemmParameters(output_type)] + ['n'])
    emitter.EmitFunctionEnd()


def GenerateInternalFunctions(emitter):
    """Generate all the functions hidden in the internal namespace."""
    for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
        for aligned in [True, False]:
            for rows in range(0, 3):
                for cols in range(0, 3):
                    for leftover in range(0, 8):
                        GenerateGemm(emitter, output_type, aligned, rows,
                                     cols, leftover)
                        emitter.EmitNewline()


def GeneratePublicFunctions(emitter):
    """Emits the strided main gemm and its wrapper for every output type."""
    for output_type in [_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT]:
        GenerateMainGemmFunction(emitter, output_type)
        emitter.EmitNewline()

        GenerateWrapperGemmFunction(emitter, output_type)
        emitter.EmitNewline()