/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "ExecutionBurstController"
#include "ExecutionBurstController.h"
#include <android-base/logging.h>
#include <cstring>
#include <limits>
#include <string>
#include "Tracing.h"
namespace android::nn {
namespace {
// Shorthand names for the MQDescriptorSync descriptor types of the request and
// result channels used throughout this file.
using ::android::hardware::MQDescriptorSync;
using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
// Timing value with both fields set to the largest uint64_t. tryCompute
// returns it together with GENERAL_FAILURE when no result packet was obtained.
constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
std::numeric_limits<uint64_t>::max()};
class BurstContextDeathHandler : public hardware::hidl_death_recipient {
public:
using Callback = std::function<void()>;
BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
CHECK(onDeathCallback != nullptr);
}
void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
mOnDeathCallback();
}
private:
const Callback mOnDeathCallback;
};
} // anonymous namespace
// serialize a request into a packet
std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure,
const std::vector<int32_t>& slots) {
// count how many elements need to be sent for a request
size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
for (const auto& input : request.inputs) {
count += input.dimensions.size();
}
for (const auto& output : request.outputs) {
count += output.dimensions.size();
}
// create buffer to temporarily store elements
std::vector<FmqRequestDatum> data;
data.reserve(count);
// package packetInfo
{
FmqRequestDatum datum;
datum.packetInformation(
{/*.packetSize=*/static_cast<uint32_t>(count),
/*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
/*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
/*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
data.push_back(datum);
}
// package input data
for (const auto& input : request.inputs) {
// package operand information
FmqRequestDatum datum;
datum.inputOperandInformation(
{/*.hasNoValue=*/input.hasNoValue,
/*.location=*/input.location,
/*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
data.push_back(datum);
// package operand dimensions
for (uint32_t dimension : input.dimensions) {
FmqRequestDatum datum;
datum.inputOperandDimensionValue(dimension);
data.push_back(datum);
}
}
// package output data
for (const auto& output : request.outputs) {
// package operand information
FmqRequestDatum datum;
datum.outputOperandInformation(
{/*.hasNoValue=*/output.hasNoValue,
/*.location=*/output.location,
/*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
data.push_back(datum);
// package operand dimensions
for (uint32_t dimension : output.dimensions) {
FmqRequestDatum datum;
datum.outputOperandDimensionValue(dimension);
data.push_back(datum);
}
}
// package pool identifier
for (int32_t slot : slots) {
FmqRequestDatum datum;
datum.poolIdentifier(slot);
data.push_back(datum);
}
// package measureTiming
{
FmqRequestDatum datum;
datum.measureTiming(measure);
data.push_back(datum);
}
// return packet
return data;
}
// deserialize a packet into the result
std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
const std::vector<FmqResultDatum>& data) {
using discriminator = FmqResultDatum::hidl_discriminator;
std::vector<OutputShape> outputShapes;
size_t index = 0;
// validate packet information
if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage packet information
const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
index++;
const uint32_t packetSize = packetInfo.packetSize;
const ErrorStatus errorStatus = packetInfo.errorStatus;
const uint32_t numberOfOperands = packetInfo.numberOfOperands;
// verify packet size
if (data.size() != packetSize) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage operands
for (size_t operand = 0; operand < numberOfOperands; ++operand) {
// validate operand information
if (data[index].getDiscriminator() != discriminator::operandInformation) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage operand information
const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
index++;
const bool isSufficient = operandInfo.isSufficient;
const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
// unpackage operand dimensions
std::vector<uint32_t> dimensions;
dimensions.reserve(numberOfDimensions);
for (size_t i = 0; i < numberOfDimensions; ++i) {
// validate dimension
if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage dimension
const uint32_t dimension = data[index].operandDimensionValue();
index++;
// store result
dimensions.push_back(dimension);
}
// store result
outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
}
// validate execution timing
if (data[index].getDiscriminator() != discriminator::executionTiming) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// unpackage execution timing
const Timing timing = data[index].executionTiming();
index++;
// validate packet information
if (index != packetSize) {
LOG(ERROR) << "FMQ Result packet ill-formed";
return std::nullopt;
}
// return result
return std::make_tuple(errorStatus, std::move(outputShapes), timing);
}
std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
ResultChannelReceiver::create(size_t channelLength, bool blocking) {
std::unique_ptr<FmqResultChannel> fmqResultChannel =
std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/blocking);
if (!fmqResultChannel->isValid()) {
LOG(ERROR) << "Unable to create ResultChannelReceiver";
return {nullptr, nullptr};
}
const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
return std::make_pair(
std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), blocking),
descriptor);
}
ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
bool blocking)
: mFmqResultChannel(std::move(fmqResultChannel)), mBlocking(blocking) {}
// Waits for the next result packet and deserializes it. Returns std::nullopt
// if no packet could be retrieved or if the packet could not be deserialized.
std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>>
ResultChannelReceiver::getBlocking() {
    if (const auto packet = getPacketBlocking(); packet.has_value()) {
        return deserialize(*packet);
    }
    return std::nullopt;
}
void ResultChannelReceiver::invalidate() {
mValid = false;
// force unblock
// ExecutionBurstController waits on a result packet after sending a
// request. If the driver containing ExecutionBurstServer crashes, the
// controller will still be waiting on the futex (assuming mBlocking is
// true). This force unblock wakes up any thread waiting on the futex.
if (mBlocking) {
// TODO: look for a different/better way to signal/notify the futex to
// wake up any thread waiting on it
FmqResultDatum datum;
datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/ErrorStatus::GENERAL_FAILURE,
/*.numberOfOperands=*/0});
mFmqResultChannel->writeBlocking(&datum, 1);
}
}
// Retrieves one raw result packet from the result FMQ.
//
// Returns std::nullopt if the receiver is already invalid on entry, if it is
// found invalid after the reads complete, or (after logging an error) if any
// of the reads failed. Otherwise returns every element that was read: the
// first element followed by everything else that was available in the queue.
std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
// NOTE(review): this alias is not used anywhere in the function body.
using discriminator = FmqResultDatum::hidl_discriminator;
// nothing to wait for once the receiver has been invalidated
if (!mValid) {
return std::nullopt;
}
// wait for result packet and read first element of result packet
FmqResultDatum datum;
bool success = true;
if (mBlocking) {
success = mFmqResultChannel->readBlocking(&datum, 1);
} else {
// busy-wait: keep polling the queue until an element is read or the
// receiver is invalidated (in which case success ends up false)
while ((success = mValid.load(std::memory_order_relaxed)) &&
!mFmqResultChannel->read(&datum, 1)) {
}
}
// retrieve remaining elements
// NOTE: all of the data is already available at this point, so there's no
// need to do a blocking wait to wait for more data. This is known because
// in FMQ, all writes are published (made available) atomically. Currently,
// the producer always publishes the entire packet in one function call, so
// if the first element of the packet is available, the remaining elements
// are also available.
const size_t count = mFmqResultChannel->availableToRead();
// one extra slot at the front holds the element that was read above
std::vector<FmqResultDatum> packet(count + 1);
std::memcpy(&packet.front(), &datum, sizeof(datum));
success &= mFmqResultChannel->read(packet.data() + 1, count);
// the receiver may have been invalidated while waiting; in that case the
// data that was read is discarded
if (!mValid) {
return std::nullopt;
}
// ensure packet was successfully received
if (!success) {
LOG(ERROR) << "Error receiving packet";
return std::nullopt;
}
return std::make_optional(std::move(packet));
}
std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
RequestChannelSender::create(size_t channelLength, bool blocking) {
std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/blocking);
if (!fmqRequestChannel->isValid()) {
LOG(ERROR) << "Unable to create RequestChannelSender";
return {nullptr, nullptr};
}
const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
return std::make_pair(
std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel), blocking),
descriptor);
}
RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
bool blocking)
: mFmqRequestChannel(std::move(fmqRequestChannel)), mBlocking(blocking) {}
bool RequestChannelSender::send(const Request& request, MeasureTiming measure,
const std::vector<int32_t>& slots) {
const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
return sendPacket(serialized);
}
// Writes an already-serialized packet to the request FMQ.
//
// Returns false without writing if the sender has been invalidated or if the
// packet is larger than the space currently available in the FMQ (the latter
// is logged). Otherwise returns the result of the blocking or non-blocking
// FMQ write, depending on how the sender was configured.
bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
    if (!mValid) {
        return false;
    }

    const size_t packetSize = packet.size();
    if (packetSize > mFmqRequestChannel->availableToWrite()) {
        LOG(ERROR)
                << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
        return false;
    }

    return mBlocking ? mFmqRequestChannel->writeBlocking(packet.data(), packetSize)
                     : mFmqRequestChannel->write(packet.data(), packetSize);
}
// Marks this sender as invalid. Once mValid is false, sendPacket returns false
// without writing anything to the request FMQ.
void RequestChannelSender::invalidate() {
mValid = false;
}
// Looks up the cached memory for every requested slot and reports the result
// through `cb`.
//
// A slot that is negative or not below the size of the cache maps to a
// default-constructed hidl_memory. If any resulting memory is not valid, `cb`
// is invoked with INVALID_ARGUMENT and an empty list; otherwise it is invoked
// with NONE and the memories in the same order as `slots`.
Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
        const hidl_vec<int32_t>& slots, getMemories_cb cb) {
    std::lock_guard<std::mutex> guard(mMutex);

    // get all memories
    hidl_vec<hidl_memory> memories(slots.size());
    for (size_t i = 0; i < slots.size(); ++i) {
        const int32_t slot = slots[i];
        const bool inRange = slot >= 0 && static_cast<size_t>(slot) < mMemoryCache.size();
        memories[i] = inRange ? mMemoryCache[slot] : hidl_memory{};
    }

    // ensure all memories are valid
    for (const hidl_memory& memory : memories) {
        if (!memory.valid()) {
            cb(ErrorStatus::INVALID_ARGUMENT, {});
            return Void();
        }
    }

    // return successful
    cb(ErrorStatus::NONE, std::move(memories));
    return Void();
}
std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
std::lock_guard<std::mutex> guard(mMutex);
// retrieve (or bind) all slots corresponding to memories
std::vector<int32_t> slots;
slots.reserve(memories.size());
for (size_t i = 0; i < memories.size(); ++i) {
slots.push_back(getSlotLocked(memories[i], keys[i]));
}
return slots;
}
// Releases the slot bound to `key`, if any.
//
// Returns {false, 0} when the key is unknown. Otherwise removes the key's
// mapping, resets the cached memory for its slot, records the slot as free for
// reuse, and returns {true, slot}.
std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
        intptr_t key) {
    std::lock_guard<std::mutex> guard(mMutex);

    const auto entry = mMemoryIdToSlot.find(key);
    if (entry == mMemoryIdToSlot.end()) {
        return {false, 0};
    }

    // read the slot out before the map entry is erased
    const int32_t slot = entry->second;
    mMemoryIdToSlot.erase(entry);
    mMemoryCache[slot] = {};
    mFreeSlots.push(slot);
    return {true, slot};
}
// Returns the slot bound to `key`. If the key has no slot yet, a slot is
// allocated, bound to the key, and `memory` is stored in the cache at that
// slot. The "Locked" suffix suggests the caller must hold mMutex (getSlots
// does).
int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
                                                                        intptr_t key) {
    // already bound: reuse the existing slot
    if (const auto found = mMemoryIdToSlot.find(key); found != mMemoryIdToSlot.end()) {
        return found->second;
    }

    // first use of this key: bind a slot and cache the memory
    const int32_t slot = allocateSlotLocked();
    mMemoryIdToSlot[key] = slot;
    mMemoryCache[slot] = memory;
    return slot;
}
// Returns a slot index that is not currently in use: a previously freed slot
// when one exists, otherwise a new slot appended to the end of the memory
// cache. CHECK-fails if appending would reach the int32_t maximum.
int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
    constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();

    // recycle a freed slot when possible
    if (!mFreeSlots.empty()) {
        const int32_t recycled = mFreeSlots.top();
        mFreeSlots.pop();
        return recycled;
    }

    // otherwise grow the cache by one entry and hand out its index
    CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
    mMemoryCache.emplace_back();
    return static_cast<int32_t>(mMemoryCache.size() - 1);
}
// Creates a burst controller for the given prepared model.
//
// Steps performed:
//   1. create the memory callback object and the request/result FMQs
//   2. call IPreparedModel::configureExecutionBurst with the callback and
//      both FMQ descriptors to obtain an IBurstContext
//   3. register a death recipient on the burst context that invalidates both
//      channels when the service dies
//
// Returns nullptr (after logging an error) if preparedModel is null, if either
// FMQ could not be created, if configureExecutionBurst fails or yields no
// context, or if the death recipient could not be registered.
std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
        const sp<IPreparedModel>& preparedModel, bool blocking) {
    // check inputs
    if (preparedModel == nullptr) {
        LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
        return nullptr;
    }

    // create callback object
    sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();

    // create FMQ objects
    auto [requestChannelSenderTemp, requestChannelDescriptor] =
            RequestChannelSender::create(kExecutionBurstChannelLength, blocking);
    auto [resultChannelReceiverTemp, resultChannelDescriptor] =
            ResultChannelReceiver::create(kExecutionBurstChannelLength, blocking);
    // the channels are shared between the controller and the death callback
    std::shared_ptr<RequestChannelSender> requestChannelSender =
            std::move(requestChannelSenderTemp);
    std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
            std::move(resultChannelReceiverTemp);

    // check FMQ objects
    if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
        !resultChannelDescriptor) {
        LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
        return nullptr;
    }

    // configure burst
    // errorStatus starts out as a failure so that it never holds an
    // indeterminate value if the callback below is not invoked
    ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
    sp<IBurstContext> burstContext;
    const Return<void> ret = preparedModel->configureExecutionBurst(
            callback, *requestChannelDescriptor, *resultChannelDescriptor,
            [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
                errorStatus = status;
                burstContext = context;
            });

    // check burst
    if (!ret.isOk()) {
        LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
                   << ret.description();
        return nullptr;
    }
    if (errorStatus != ErrorStatus::NONE) {
        LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
                   << toString(errorStatus);
        return nullptr;
    }
    if (burstContext == nullptr) {
        LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
        return nullptr;
    }

    // create death handler object; the lambda holds its own references to the
    // channels, so they stay alive for as long as the handler does
    BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
                                                          resultChannelReceiver] {
        requestChannelSender->invalidate();
        resultChannelReceiver->invalidate();
    };
    const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);

    // linkToDeath registers a callback that will be invoked on service death to
    // proactively handle service crashes. If the linkToDeath call fails,
    // asynchronous calls are susceptible to hangs if the service crashes before
    // providing the response.
    const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
    if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
        LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
                      "for the IBurstContext object.";
        return nullptr;
    }

    // make and return controller
    return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
                                                      burstContext, callback, deathHandler);
}
ExecutionBurstController::ExecutionBurstController(
const std::shared_ptr<RequestChannelSender>& requestChannelSender,
const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
const sp<hardware::hidl_death_recipient>& deathHandler)
: mRequestChannelSender(requestChannelSender),
mResultChannelReceiver(resultChannelReceiver),
mBurstContext(burstContext),
mMemoryCache(callback),
mDeathHandler(deathHandler) {}
// Unregisters the death recipient from the burst context, if one was set.
ExecutionBurstController::~ExecutionBurstController() {
    if (mDeathHandler == nullptr) {
        return;
    }

    // Any error from unlinkToDeath is deliberately ignored: the controller is
    // already being destroyed and its underlying IBurstContext object is no
    // longer being used by the NN runtime. isOk() is still called on the
    // returned value, as in every other place in this file where a HIDL
    // result is otherwise discarded.
    mBurstContext->unlinkToDeath(mDeathHandler).isOk();
}
// Runs one execution on the burst and returns (status, output shapes,
// timing). This is tryCompute with the trailing fallback flag dropped.
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
        const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
    auto result = tryCompute(request, measure, memoryIds);
    // element 3 (the fallback flag) is intentionally not forwarded
    return std::make_tuple(std::get<0>(result), std::move(std::get<1>(result)),
                           std::get<2>(result));
}
// Runs one execution on the burst and returns (status, output shapes, timing,
// fallback).
//
// The request is sent over the request FMQ and the result is then read from
// the result FMQ, all while holding mMutex. `fallback` is true only when the
// request packet could not be sent; once the packet has been sent, fallback is
// always false, even if retrieving the result fails. On either failure the
// status is GENERAL_FAILURE with no output shapes and kNoTiming.
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool>
ExecutionBurstController::tryCompute(const Request& request, MeasureTiming measure,
                                     const std::vector<intptr_t>& memoryIds) {
    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");

    std::lock_guard<std::mutex> guard(mMutex);

    // send request packet
    const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
    const bool success = mRequestChannelSender->send(request, measure, slots);
    if (!success) {
        LOG(ERROR) << "Error sending FMQ packet";
        // only use fallback execution path if the packet could not be sent
        return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/true};
    }

    // get result packet; `result` is intentionally non-const so that its
    // contents can be moved out below instead of being copied
    auto result = mResultChannelReceiver->getBlocking();
    if (!result) {
        LOG(ERROR) << "Error retrieving FMQ packet";
        // the packet was already sent, so the fallback path must not be used
        return {ErrorStatus::GENERAL_FAILURE, {}, kNoTiming, /*fallback=*/false};
    }

    // unpack results and return (only use fallback execution path if the
    // packet could not be sent)
    auto [status, outputShapes, timing] = std::move(*result);
    return {status, std::move(outputShapes), timing, /*fallback=*/false};
}
// Releases the slot bound to `key` in the local memory cache and, if the key
// was known, asks the burst context to free the same slot. Any error from
// that call is ignored.
void ExecutionBurstController::freeMemory(intptr_t key) {
    std::lock_guard<std::mutex> guard(mMutex);

    const auto [valid, slot] = mMemoryCache->freeMemory(key);
    if (!valid) {
        return;
    }
    mBurstContext->freeMemory(slot).isOk();
}
} // namespace android::nn