/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANDROID_ML_NN_RUNTIME_MEMORY_H #define ANDROID_ML_NN_RUNTIME_MEMORY_H #include "NeuralNetworks.h" #include "Utils.h" #include <cutils/native_handle.h> #include <sys/mman.h> #include <mutex> #include <unordered_map> #include "vndk/hardware_buffer.h" namespace android { namespace nn { class ExecutionBurstController; class ModelBuilder; // Represents a memory region. class Memory { public: Memory() {} virtual ~Memory(); // Disallow copy semantics to ensure the runtime object can only be freed // once. Copy semantics could be enabled if some sort of reference counting // or deep-copy system for runtime objects is added later. Memory(const Memory&) = delete; Memory& operator=(const Memory&) = delete; // Creates a shared memory object of the size specified in bytes. int create(uint32_t size); hardware::hidl_memory getHidlMemory() const { return mHidlMemory; } // Returns a pointer to the underlying memory of this memory object. // The function will fail if the memory is not CPU accessible and nullptr // will be returned. virtual int getPointer(uint8_t** buffer) const { *buffer = static_cast<uint8_t*>(static_cast<void*>(mMemory->getPointer())); if (*buffer == nullptr) { return ANEURALNETWORKS_BAD_DATA; } return ANEURALNETWORKS_NO_ERROR; } virtual bool validateSize(uint32_t offset, uint32_t length) const; // Unique key representing this memory object. intptr_t getKey() const; // Marks a burst object as currently using this memory. When this // memory object is destroyed, it will automatically free this memory from // the bursts' memory cache. void usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const; protected: // The hidl_memory handle for this shared memory. We will pass this value when // communicating with the drivers. hardware::hidl_memory mHidlMemory; sp<IMemory> mMemory; mutable std::mutex mMutex; // mUsedBy is essentially a set of burst objects which use this Memory // object. However, std::weak_ptr does not have comparison operations nor a // std::hash implementation. This is because it is either a valid pointer // (non-null) if the shared object is still alive, or it is null if the // object has been freed. To circumvent this, mUsedBy is a map with the raw // pointer as the key and the weak_ptr as the value. mutable std::unordered_map<const ExecutionBurstController*, std::weak_ptr<ExecutionBurstController>> mUsedBy; }; class MemoryFd : public Memory { public: MemoryFd() {} ~MemoryFd() override; // Disallow copy semantics to ensure the runtime object can only be freed // once. Copy semantics could be enabled if some sort of reference counting // or deep-copy system for runtime objects is added later. MemoryFd(const MemoryFd&) = delete; MemoryFd& operator=(const MemoryFd&) = delete; // Create the native_handle based on input size, prot, and fd. // Existing native_handle will be deleted, and mHidlMemory will wrap // the newly created native_handle. int set(size_t size, int prot, int fd, size_t offset); int getPointer(uint8_t** buffer) const override; private: native_handle_t* mHandle = nullptr; mutable uint8_t* mMapping = nullptr; }; // TODO(miaowang): move function definitions to Memory.cpp class MemoryAHWB : public Memory { public: MemoryAHWB() {} ~MemoryAHWB() override{}; // Disallow copy semantics to ensure the runtime object can only be freed // once. Copy semantics could be enabled if some sort of reference counting // or deep-copy system for runtime objects is added later. MemoryAHWB(const MemoryAHWB&) = delete; MemoryAHWB& operator=(const MemoryAHWB&) = delete; // Keep track of the provided AHardwareBuffer handle. int set(const AHardwareBuffer* ahwb) { AHardwareBuffer_describe(ahwb, &mBufferDesc); const native_handle_t* handle = AHardwareBuffer_getNativeHandle(ahwb); mHardwareBuffer = ahwb; if (mBufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) { mHidlMemory = hidl_memory("hardware_buffer_blob", handle, mBufferDesc.width); } else { // memory size is not used. mHidlMemory = hidl_memory("hardware_buffer", handle, 0); } return ANEURALNETWORKS_NO_ERROR; }; int getPointer(uint8_t** buffer) const override { *buffer = nullptr; return ANEURALNETWORKS_BAD_DATA; }; // validateSize should only be called for blob mode AHardwareBuffer. // Calling it on non-blob mode AHardwareBuffer will result in an error. // TODO(miaowang): consider separate blob and non-blob into different classes. bool validateSize(uint32_t offset, uint32_t length) const override { if (mHardwareBuffer == nullptr) { LOG(ERROR) << "MemoryAHWB has not been initialized."; return false; } // validateSize should only be called on BLOB mode buffer. if (mBufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) { if (offset + length > mBufferDesc.width) { LOG(ERROR) << "Request size larger than the memory size."; return false; } else { return true; } } else { LOG(ERROR) << "Invalid AHARDWAREBUFFER_FORMAT, must be AHARDWAREBUFFER_FORMAT_BLOB."; return false; } } private: const AHardwareBuffer* mHardwareBuffer = nullptr; AHardwareBuffer_Desc mBufferDesc; }; // A utility class to accumulate mulitple Memory objects and assign each // a distinct index number, starting with 0. // // The user of this class is responsible for avoiding concurrent calls // to this class from multiple threads. class MemoryTracker { private: // The vector of Memory pointers we are building. std::vector<const Memory*> mMemories; // A faster way to see if we already have a memory than doing find(). std::unordered_map<const Memory*, uint32_t> mKnown; public: // Adds the memory, if it does not already exists. Returns its index. // The memories should survive the tracker. uint32_t add(const Memory* memory); // Returns the number of memories contained. uint32_t size() const { return static_cast<uint32_t>(mKnown.size()); } // Returns the ith memory. const Memory* operator[](size_t i) const { return mMemories[i]; } // Iteration decltype(mMemories.begin()) begin() { return mMemories.begin(); } decltype(mMemories.end()) end() { return mMemories.end(); } }; } // namespace nn } // namespace android #endif // ANDROID_ML_NN_RUNTIME_MEMORY_H