/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Provides C++ classes to more easily use the Neural Networks API.
// TODO(b/117845862): this should be auto generated from NeuralNetworksWrapper.h.

#ifndef ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H
#define ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H

#include "NeuralNetworks.h"
#include "NeuralNetworksWrapper.h"
#include "NeuralNetworksWrapperExtensions.h"

#include <math.h>

#include <optional>
#include <string>
#include <vector>

namespace android {
namespace nn {
namespace test_wrapper {

using wrapper::Event;
using wrapper::ExecutePreference;
using wrapper::ExtensionModel;
using wrapper::ExtensionOperandParams;
using wrapper::ExtensionOperandType;
using wrapper::Memory;
using wrapper::Model;
using wrapper::OperandType;
using wrapper::Result;
using wrapper::SymmPerChannelQuantParams;
using wrapper::Type;

class Compilation {
public:
Compilation(const Model* model) {
int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
if (result != 0) {
// TODO Handle the error
}
}
Compilation() {}
~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
// Disallow copy semantics to ensure the runtime object can only be freed
// once. Copy semantics could be enabled if some sort of reference counting
// or deep-copy system for runtime objects is added later.
Compilation(const Compilation&) = delete;
Compilation& operator=(const Compilation&) = delete;
// Move semantics to remove access to the runtime object from the wrapper
// object that is being moved. This ensures the runtime object will be
// freed only once.
Compilation(Compilation&& other) { *this = std::move(other); }
Compilation& operator=(Compilation&& other) {
if (this != &other) {
ANeuralNetworksCompilation_free(mCompilation);
mCompilation = other.mCompilation;
other.mCompilation = nullptr;
}
return *this;
}
Result setPreference(ExecutePreference preference) {
return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
mCompilation, static_cast<int32_t>(preference)));
}
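
    // Configures compilation caching. The token must be exactly
    // ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN bytes long. A minimal sketch
    // (assuming a Compilation object named "compilation"; the cache directory
    // below is a hypothetical example):
    //
    //   std::vector<uint8_t> token(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN, 0);
    //   compilation.setCaching("/data/local/tmp", token);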
Result setCaching(const std::string& cacheDir, const std::vector<uint8_t>& token) {
if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN) {
return Result::BAD_DATA;
}
return static_cast<Result>(ANeuralNetworksCompilation_setCaching(
mCompilation, cacheDir.c_str(), token.data()));
}
Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }
ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
protected:
ANeuralNetworksCompilation* mCompilation = nullptr;
};
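
// Example usage of Compilation (a minimal sketch, assuming a hypothetical
// Model named "model" whose construction has already been finished):
//
//   Compilation compilation(&model);
//   compilation.setPreference(ExecutePreference::PREFER_FAST_SINGLE_ANSWER);
//   if (compilation.finish() != Result::NO_ERROR) {
//       // Handle the compilation failure.
//   }
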
class Execution {
public:
Execution(const Compilation* compilation) : mCompilation(compilation->getHandle()) {
int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
if (result != 0) {
// TODO Handle the error
}
}
~Execution() { ANeuralNetworksExecution_free(mExecution); }
// Disallow copy semantics to ensure the runtime object can only be freed
// once. Copy semantics could be enabled if some sort of reference counting
// or deep-copy system for runtime objects is added later.
Execution(const Execution&) = delete;
Execution& operator=(const Execution&) = delete;
// Move semantics to remove access to the runtime object from the wrapper
// object that is being moved. This ensures the runtime object will be
// freed only once.
Execution(Execution&& other) { *this = std::move(other); }
Execution& operator=(Execution&& other) {
if (this != &other) {
ANeuralNetworksExecution_free(mExecution);
mCompilation = other.mCompilation;
other.mCompilation = nullptr;
mExecution = other.mExecution;
other.mExecution = nullptr;
}
return *this;
}
Result setInput(uint32_t index, const void* buffer, size_t length,
const ANeuralNetworksOperandType* type = nullptr) {
return static_cast<Result>(
ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
}
Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
mExecution, index, type, memory->get(), offset, length));
}
Result setOutput(uint32_t index, void* buffer, size_t length,
const ANeuralNetworksOperandType* type = nullptr) {
return static_cast<Result>(
ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
}
Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
mExecution, index, type, memory->get(), offset, length));
}
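
    // startCompute() launches an asynchronous computation and hands the
    // resulting event to the caller. A minimal sketch (assuming an Execution
    // object named "execution"; error handling abbreviated):
    //
    //   Event event;
    //   execution.startCompute(&event);
    //   event.wait();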
Result startCompute(Event* event) {
ANeuralNetworksEvent* ev = nullptr;
Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
event->set(ev);
return result;
}
Result compute() {
switch (mComputeMode) {
case ComputeMode::SYNC: {
return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
}
case ComputeMode::ASYNC: {
ANeuralNetworksEvent* event = nullptr;
Result result = static_cast<Result>(
ANeuralNetworksExecution_startCompute(mExecution, &event));
if (result != Result::NO_ERROR) {
return result;
}
                // TODO: It is not yet clear how to manage the lifetime of
                // events when there are multiple waiters.
result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
ANeuralNetworksEvent_free(event);
return result;
}
case ComputeMode::BURST: {
ANeuralNetworksBurst* burst = nullptr;
Result result =
static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
if (result != Result::NO_ERROR) {
return result;
}
result = static_cast<Result>(
ANeuralNetworksExecution_burstCompute(mExecution, burst));
ANeuralNetworksBurst_free(burst);
return result;
}
}
return Result::BAD_DATA;
}
    // By default, compute() uses the synchronous API. setComputeMode() can be
    // used to change the behavior of compute() to either:
    // - use the asynchronous API and then wait for computation to complete, or
    // - use the burst API
    enum class ComputeMode { SYNC, ASYNC, BURST };

    // Sets the mode used by subsequent calls to compute() and returns the
    // previous ComputeMode.
    static ComputeMode setComputeMode(ComputeMode mode) {
ComputeMode oldComputeMode = mComputeMode;
mComputeMode = mode;
return oldComputeMode;
}
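
    // A minimal sketch of temporarily routing compute() through the burst
    // path and then restoring the previous mode (assuming an Execution object
    // named "execution"):
    //
    //   Execution::ComputeMode oldMode =
    //           Execution::setComputeMode(Execution::ComputeMode::BURST);
    //   Result result = execution.compute();
    //   Execution::setComputeMode(oldMode);

    // Returns the dimensions of an output operand. The rank is queried first
    // so the vector can be resized before the dimensions are fetched;
    // OUTPUT_INSUFFICIENT_SIZE is tolerated because the runtime can still
    // report the actual output dimensions in that case.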
Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
uint32_t rank = 0;
Result result = static_cast<Result>(
ANeuralNetworksExecution_getOutputOperandRank(mExecution, index, &rank));
dimensions->resize(rank);
if ((result != Result::NO_ERROR && result != Result::OUTPUT_INSUFFICIENT_SIZE) ||
rank == 0) {
return result;
}
result = static_cast<Result>(ANeuralNetworksExecution_getOutputOperandDimensions(
mExecution, index, dimensions->data()));
return result;
}
private:
ANeuralNetworksCompilation* mCompilation = nullptr;
ANeuralNetworksExecution* mExecution = nullptr;
// Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp.
static ComputeMode mComputeMode;
};
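
// Example end-to-end usage of Execution (a minimal sketch; "compilation",
// "input", and "output" are hypothetical names, input/output are assumed to
// be std::vector<float>, and error handling is abbreviated):
//
//   Execution execution(&compilation);
//   execution.setInput(0, input.data(), input.size() * sizeof(float));
//   execution.setOutput(0, output.data(), output.size() * sizeof(float));
//   if (execution.compute() == Result::NO_ERROR) {
//       std::vector<uint32_t> dimensions;
//       execution.getOutputOperandDimensions(0, &dimensions);
//   }
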
} // namespace test_wrapper
} // namespace nn
} // namespace android
#endif // ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H