/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H #define ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H #include "HalInterfaces.h" #include "OperationsUtils.h" namespace android { namespace nn { // Encapsulates an operation implementation. struct OperationRegistration { OperationType type; const char* name; // Validates operand types, shapes, and any values known during graph creation. std::function<bool(const IOperationValidationContext*)> validate; // prepare is called when the inputs this operation depends on have been // computed. Typically, prepare does any remaining validation and sets // output shapes via context->setOutputShape(...). std::function<bool(IOperationExecutionContext*)> prepare; // Executes the operation, reading from context->getInputBuffer(...) // and writing to context->getOutputBuffer(...). std::function<bool(IOperationExecutionContext*)> execute; struct Flag { // Whether the operation allows at least one operand to be omitted. bool allowOmittedOperand = false; // Whether the operation allows at least one input operand to be a zero-sized tensor. bool allowZeroSizedInput = false; } flags; OperationRegistration(OperationType type, const char* name, std::function<bool(const IOperationValidationContext*)> validate, std::function<bool(IOperationExecutionContext*)> prepare, std::function<bool(IOperationExecutionContext*)> execute, Flag flags) : type(type), name(name), validate(validate), prepare(prepare), execute(execute), flags(flags) {} }; // A registry of operation implementations. class IOperationResolver { public: virtual const OperationRegistration* findOperation(OperationType operationType) const = 0; virtual ~IOperationResolver() {} }; // A registry of builtin operation implementations. // // Note that some operations bypass BuiltinOperationResolver (b/124041202). // // Usage: // const OperationRegistration* operationRegistration = // BuiltinOperationResolver::get()->findOperation(operationType); // NN_RET_CHECK(operationRegistration != nullptr); // NN_RET_CHECK(operationRegistration->validate != nullptr); // NN_RET_CHECK(operationRegistration->validate(&context)); // class BuiltinOperationResolver : public IOperationResolver { DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver); public: static const BuiltinOperationResolver* get() { static BuiltinOperationResolver instance; return &instance; } const OperationRegistration* findOperation(OperationType operationType) const override; private: BuiltinOperationResolver(); void registerOperation(const OperationRegistration* operationRegistration); const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {}; }; // NN_REGISTER_OPERATION creates OperationRegistration for consumption by // OperationResolver. // // Usage: // (check OperationRegistration::Flag for available fields and default values.) // // - With default flags. // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, // foo_op::prepare, foo_op::execute); // // - With a customized flag. // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, // foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true); // // - With multiple customized flags. // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, // foo_op::prepare, foo_op::execute, .allowOmittedOperand = true, // .allowZeroSizedInput = true); // #ifdef NN_INCLUDE_CPU_IMPLEMENTATION #define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...) \ const OperationRegistration* register_##identifier() { \ static OperationRegistration registration(OperationType::identifier, operationName, \ validate, prepare, execute, {__VA_ARGS__}); \ return ®istration; \ } #else // This version ignores CPU execution logic (prepare and execute). // The compiler is supposed to omit that code so that only validation logic // makes it into libneuralnetworks_utils. #define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \ ...) \ const OperationRegistration* register_##identifier() { \ static OperationRegistration registration(OperationType::identifier, operationName, \ validate, nullptr, nullptr, {__VA_ARGS__}); \ return ®istration; \ } #endif } // namespace nn } // namespace android #endif // ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H