/**
* Copyright 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "run_tflite.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/lite/kernels/register.h"
#include <android/log.h>
#include <dlfcn.h>
#include <sys/time.h>
#include <cstdio>
#define LOG_TAG "NN_BENCHMARK"
#define FATAL(fmt, ...) \
do { \
__android_log_print(ANDROID_LOG_FATAL, LOG_TAG, fmt, ##__VA_ARGS__); \
assert(false); \
} while (0)
namespace {
long long currentTimeInUsec() {
timeval tv;
gettimeofday(&tv, NULL);
return ((tv.tv_sec * 1000000L) + tv.tv_usec);
}
// Workaround for build systems that make difficult to pick the correct NDK API
// level. NDK tracing methods are dynamically loaded from libandroid.so.
typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
typedef void* (*fp_ATrace_endSection)();
struct TraceFunc {
fp_ATrace_beginSection ATrace_beginSection;
fp_ATrace_endSection ATrace_endSection;
};
TraceFunc setupTraceFunc() {
void* lib = dlopen("libandroid.so", RTLD_NOW | RTLD_LOCAL);
if (lib == nullptr) {
FATAL("unable to open libandroid.so");
}
return {
reinterpret_cast<fp_ATrace_beginSection>(
dlsym(lib, "ATrace_beginSection")),
reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"))};
}
static TraceFunc kTraceFunc{setupTraceFunc()};
} // namespace
BenchmarkModel* BenchmarkModel::create(const char* modelfile, bool use_nnapi,
bool enable_intermediate_tensors_dump,
const char* nnapi_device_name) {
BenchmarkModel* model = new BenchmarkModel();
if (!model->init(modelfile, use_nnapi, enable_intermediate_tensors_dump,
nnapi_device_name)) {
delete model;
return nullptr;
}
return model;
}
bool BenchmarkModel::init(const char* modelfile, bool use_nnapi,
bool enable_intermediate_tensors_dump,
const char* nnapi_device_name) {
__android_log_print(ANDROID_LOG_INFO, LOG_TAG, "BenchmarkModel %s",
modelfile);
// Memory map the model. NOTE this needs lifetime greater than or equal
// to interpreter context.
mTfliteModel = tflite::FlatBufferModel::BuildFromFile(modelfile);
if (!mTfliteModel) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to load model %s",
modelfile);
return false;
}
tflite::ops::builtin::BuiltinOpResolver resolver;
tflite::InterpreterBuilder(*mTfliteModel, resolver)(&mTfliteInterpreter);
if (!mTfliteInterpreter) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
"Failed to create TFlite interpreter");
return false;
}
if (enable_intermediate_tensors_dump) {
// Make output of every op a model output. This way we will be able to
// fetch each intermediate tensor when running with delegates.
std::vector<int> outputs;
for (size_t node = 0; node < mTfliteInterpreter->nodes_size(); ++node) {
auto node_outputs =
mTfliteInterpreter->node_and_registration(node)->first.outputs;
outputs.insert(outputs.end(), node_outputs->data,
node_outputs->data + node_outputs->size);
}
mTfliteInterpreter->SetOutputs(outputs);
}
// Allow Fp16 precision for all models
mTfliteInterpreter->SetAllowFp16PrecisionForFp32(true);
if (use_nnapi) {
if (nnapi_device_name != nullptr) {
__android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Running NNAPI on device %s",
nnapi_device_name);
}
if (mTfliteInterpreter->ModifyGraphWithDelegate(
tflite::NnApiDelegate(nnapi_device_name)) != kTfLiteOk) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
"Failed to initialize NNAPI Delegate");
return false;
}
}
return true;
}
BenchmarkModel::BenchmarkModel() {}
BenchmarkModel::~BenchmarkModel() {}
bool BenchmarkModel::setInput(const uint8_t* dataPtr, size_t length) {
int input = mTfliteInterpreter->inputs()[0];
auto* input_tensor = mTfliteInterpreter->tensor(input);
switch (input_tensor->type) {
case kTfLiteFloat32:
case kTfLiteUInt8: {
void* raw = input_tensor->data.raw;
memcpy(raw, dataPtr, length);
break;
}
default:
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
"Input tensor type not supported");
return false;
}
return true;
}
void BenchmarkModel::saveInferenceOutput(InferenceResult* result,
int output_index) {
int output = mTfliteInterpreter->outputs()[output_index];
auto* output_tensor = mTfliteInterpreter->tensor(output);
auto& sink = result->inferenceOutputs[output_index];
sink.insert(sink.end(), output_tensor->data.uint8,
output_tensor->data.uint8 + output_tensor->bytes);
}
void BenchmarkModel::getOutputError(const uint8_t* expected_data, size_t length,
InferenceResult* result, int output_index) {
int output = mTfliteInterpreter->outputs()[output_index];
auto* output_tensor = mTfliteInterpreter->tensor(output);
if (output_tensor->bytes != length) {
FATAL("Wrong size of output tensor, expected %zu, is %zu",
output_tensor->bytes, length);
}
size_t elements_count = 0;
float err_sum = 0.0;
float max_error = 0.0;
switch (output_tensor->type) {
case kTfLiteUInt8: {
uint8_t* output_raw = mTfliteInterpreter->typed_tensor<uint8_t>(output);
elements_count = output_tensor->bytes;
for (size_t i = 0; i < output_tensor->bytes; ++i) {
float err = ((float)output_raw[i]) - ((float)expected_data[i]);
if (err > max_error) max_error = err;
err_sum += err * err;
}
break;
}
case kTfLiteFloat32: {
const float* expected = reinterpret_cast<const float*>(expected_data);
float* output_raw = mTfliteInterpreter->typed_tensor<float>(output);
elements_count = output_tensor->bytes / sizeof(float);
for (size_t i = 0; i < output_tensor->bytes / sizeof(float); ++i) {
float err = output_raw[i] - expected[i];
if (err > max_error) max_error = err;
err_sum += err * err;
}
break;
}
default:
FATAL("Output sensor type %d not supported", output_tensor->type);
}
result->meanSquareErrors[output_index] = err_sum / elements_count;
result->maxSingleErrors[output_index] = max_error;
}
bool BenchmarkModel::resizeInputTensors(std::vector<int> shape) {
// The benchmark only expects single input tensor, hardcoded as 0.
int input = mTfliteInterpreter->inputs()[0];
mTfliteInterpreter->ResizeInputTensor(input, shape);
if (mTfliteInterpreter->AllocateTensors() != kTfLiteOk) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
"Failed to allocate tensors!");
return false;
}
return true;
}
bool BenchmarkModel::runInference() {
auto status = mTfliteInterpreter->Invoke();
if (status != kTfLiteOk) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Failed to invoke: %d!",
(int)status);
return false;
}
return true;
}
bool BenchmarkModel::resetStates() {
auto status = mTfliteInterpreter->ResetVariableTensors();
if (status != kTfLiteOk) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
"Failed to reset variable tensors: %d!", (int)status);
return false;
}
return true;
}
bool BenchmarkModel::benchmark(
const std::vector<InferenceInOutSequence>& inOutData,
int seqInferencesMaxCount, float timeout, int flags,
std::vector<InferenceResult>* results) {
if (inOutData.empty()) {
FATAL("Input/output vector is empty");
}
float inferenceTotal = 0.0;
for (int seqInferenceIndex = 0; seqInferenceIndex < seqInferencesMaxCount;
++seqInferenceIndex) {
resetStates();
const int inputOutputSequenceIndex = seqInferenceIndex % inOutData.size();
const InferenceInOutSequence& seq = inOutData[inputOutputSequenceIndex];
for (int i = 0; i < seq.size(); ++i) {
const InferenceInOut& data = seq[i];
// For NNAPI systrace usage documentation, see
// frameworks/ml/nn/common/include/Tracing.h.
kTraceFunc.ATrace_beginSection("[NN_LA_PE]BenchmarkModel::benchmark");
kTraceFunc.ATrace_beginSection("[NN_LA_PIO]BenchmarkModel::input");
if (data.input) {
setInput(data.input, data.input_size);
} else {
int input = mTfliteInterpreter->inputs()[0];
auto* input_tensor = mTfliteInterpreter->tensor(input);
if (!data.createInput((uint8_t*)input_tensor->data.raw,
input_tensor->bytes)) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
"Input creation %d failed", i);
return false;
}
}
kTraceFunc.ATrace_endSection();
long long startTime = currentTimeInUsec();
const bool success = runInference();
kTraceFunc.ATrace_endSection();
long long endTime = currentTimeInUsec();
if (!success) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Inference %d failed",
i);
return false;
}
float inferenceTime =
static_cast<float>(endTime - startTime) / 1000000.0f;
size_t outputsCount = mTfliteInterpreter->outputs().size();
InferenceResult result{
inferenceTime, {}, {}, {}, inputOutputSequenceIndex, i};
result.meanSquareErrors.resize(outputsCount);
result.maxSingleErrors.resize(outputsCount);
result.inferenceOutputs.resize(outputsCount);
if ((flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0) {
if (outputsCount != data.outputs.size()) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
"Golden/actual outputs (%zu/%zu) count mismatch",
data.outputs.size(), outputsCount);
return false;
}
for (int j = 0; j < outputsCount; ++j) {
getOutputError(data.outputs[j].ptr, data.outputs[j].size, &result, j);
}
}
if ((flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0) {
for (int j = 0; j < outputsCount; ++j) {
saveInferenceOutput(&result, j);
}
}
results->push_back(result);
inferenceTotal += inferenceTime;
}
// Timeout?
if (timeout > 0.001 && inferenceTotal > timeout) {
return true;
}
}
return true;
}
bool BenchmarkModel::dumpAllLayers(
const char* path, const std::vector<InferenceInOutSequence>& inOutData) {
if (inOutData.empty()) {
FATAL("Input/output vector is empty");
}
for (int seqInferenceIndex = 0; seqInferenceIndex < inOutData.size();
++seqInferenceIndex) {
resetStates();
const InferenceInOutSequence& seq = inOutData[seqInferenceIndex];
for (int i = 0; i < seq.size(); ++i) {
const InferenceInOut& data = seq[i];
setInput(data.input, data.input_size);
const bool success = runInference();
if (!success) {
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Inference %d failed",
i);
return false;
}
for (int tensor = 0; tensor < mTfliteInterpreter->tensors_size();
++tensor) {
auto* output_tensor = mTfliteInterpreter->tensor(tensor);
if (output_tensor->data.raw == nullptr) {
continue;
}
char fullpath[1024];
snprintf(fullpath, 1024, "%s/dump_%.3d_seq_%.3d_tensor_%.3d", path,
seqInferenceIndex, i, tensor);
FILE* f = fopen(fullpath, "wb");
fwrite(output_tensor->data.raw, output_tensor->bytes, 1, f);
fclose(f);
}
}
}
return true;
}