/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_ML_NN_RUNTIME_MEMORY_H
#define ANDROID_ML_NN_RUNTIME_MEMORY_H

#include "NeuralNetworks.h"
#include "Utils.h"

#include <cutils/native_handle.h>
#include <sys/mman.h>
#include <mutex>
#include <unordered_map>
#include "vndk/hardware_buffer.h"

namespace android {
namespace nn {

class ExecutionBurstController;
class ModelBuilder;

// Represents a memory region.
class Memory {
   public:
    Memory() = default;
    virtual ~Memory();

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Memory(const Memory&) = delete;
    Memory& operator=(const Memory&) = delete;

    // Creates a shared memory object of the size specified in bytes.
    // Returns an ANEURALNETWORKS_* result code (implementation lives outside
    // this header — NOTE(review): confirm exact failure codes in Memory.cpp).
    int create(uint32_t size);

    // Returns a copy of the hidl_memory handle that is passed to drivers.
    hardware::hidl_memory getHidlMemory() const { return mHidlMemory; }

    // Stores a CPU-accessible pointer to the underlying memory in *buffer and
    // returns ANEURALNETWORKS_NO_ERROR.
    // If no IMemory is held (mMemory is unset) or it yields a null pointer,
    // *buffer is set to nullptr and ANEURALNETWORKS_BAD_DATA is returned.
    virtual int getPointer(uint8_t** buffer) const {
        // Guard against dereferencing an unset sp<IMemory>.
        if (mMemory == nullptr) {
            *buffer = nullptr;
            return ANEURALNETWORKS_BAD_DATA;
        }
        *buffer = static_cast<uint8_t*>(static_cast<void*>(mMemory->getPointer()));
        if (*buffer == nullptr) {
            return ANEURALNETWORKS_BAD_DATA;
        }
        return ANEURALNETWORKS_NO_ERROR;
    }

    // Returns whether a request of |length| bytes at |offset| fits in this
    // memory (defined outside this header; subclasses may override).
    virtual bool validateSize(uint32_t offset, uint32_t length) const;

    // Unique key representing this memory object.
    intptr_t getKey() const;

    // Marks a burst object as currently using this memory. When this
    // memory object is destroyed, it will automatically free this memory from
    // the bursts' memory cache.
    void usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const;

   protected:
    // The hidl_memory handle for this shared memory.  We will pass this value when
    // communicating with the drivers.
    hardware::hidl_memory mHidlMemory;
    // Mapped view of the memory used by getPointer(); may be unset.
    sp<IMemory> mMemory;

    // Guards mUsedBy (mutable so const methods such as usedBy() can lock it).
    mutable std::mutex mMutex;
    // mUsedBy is essentially a set of burst objects which use this Memory
    // object. However, std::weak_ptr does not have comparison operations nor a
    // std::hash implementation. This is because it is either a valid pointer
    // (non-null) if the shared object is still alive, or it is null if the
    // object has been freed. To circumvent this, mUsedBy is a map with the raw
    // pointer as the key and the weak_ptr as the value.
    mutable std::unordered_map<const ExecutionBurstController*,
                               std::weak_ptr<ExecutionBurstController>>
            mUsedBy;
};

class MemoryFd : public Memory {
   public:
    MemoryFd() = default;
    ~MemoryFd() override;

    // Copying is forbidden so that the underlying runtime object is released
    // exactly once. This could be revisited if reference counting or a
    // deep-copy mechanism for runtime objects is introduced.
    MemoryFd(const MemoryFd&) = delete;
    MemoryFd& operator=(const MemoryFd&) = delete;

    // Builds a native_handle from |fd|, |prot|, |offset| and |size|.
    // Any previously held native_handle is deleted, and mHidlMemory is updated
    // to wrap the new handle. Returns an ANEURALNETWORKS_* result code.
    int set(size_t size, int prot, int fd, size_t offset);

    // Overrides Memory::getPointer for fd-backed memory (defined outside this
    // header).
    int getPointer(uint8_t** buffer) const override;

   private:
    // Native handle created by set(); nullptr until then.
    native_handle_t* mHandle = nullptr;
    // Presumably a CPU mapping cached by the const getPointer() (hence
    // `mutable`) — NOTE(review): confirm in Memory.cpp.
    mutable uint8_t* mMapping = nullptr;
};

// TODO(miaowang): move function definitions to Memory.cpp
class MemoryAHWB : public Memory {
   public:
    MemoryAHWB() = default;
    ~MemoryAHWB() override = default;

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    MemoryAHWB(const MemoryAHWB&) = delete;
    MemoryAHWB& operator=(const MemoryAHWB&) = delete;

    // Keeps track of the provided AHardwareBuffer and wraps its native handle
    // in mHidlMemory. For BLOB-format buffers, the hidl_memory size is the
    // buffer width (the blob size in bytes). For any other format, the size is
    // 0 and unused.
    // The buffer is not acquired here, so NOTE(review): the caller presumably
    // must keep |ahwb| alive for the lifetime of this object.
    // Always returns ANEURALNETWORKS_NO_ERROR.
    int set(const AHardwareBuffer* ahwb) {
        AHardwareBuffer_describe(ahwb, &mBufferDesc);
        const native_handle_t* handle = AHardwareBuffer_getNativeHandle(ahwb);
        mHardwareBuffer = ahwb;
        if (mBufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) {
            mHidlMemory = hidl_memory("hardware_buffer_blob", handle, mBufferDesc.width);
        } else {
            // memory size is not used.
            mHidlMemory = hidl_memory("hardware_buffer", handle, 0);
        }
        return ANEURALNETWORKS_NO_ERROR;
    }

    // AHardwareBuffer memory is not CPU-mapped by this class. This always sets
    // *buffer to nullptr and returns ANEURALNETWORKS_BAD_DATA.
    int getPointer(uint8_t** buffer) const override {
        *buffer = nullptr;
        return ANEURALNETWORKS_BAD_DATA;
    }

    // Returns true if [offset, offset + length) lies within a BLOB-mode buffer.
    // Returns false (and logs) if set() has not been called, if the buffer is
    // not BLOB format, or if the range exceeds the buffer size.
    // TODO(miaowang): consider separate blob and non-blob into different classes.
    bool validateSize(uint32_t offset, uint32_t length) const override {
        if (mHardwareBuffer == nullptr) {
            LOG(ERROR) << "MemoryAHWB has not been initialized.";
            return false;
        }
        // validateSize is only meaningful for BLOB mode buffers.
        if (mBufferDesc.format != AHARDWAREBUFFER_FORMAT_BLOB) {
            LOG(ERROR) << "Invalid AHARDWAREBUFFER_FORMAT, must be AHARDWAREBUFFER_FORMAT_BLOB.";
            return false;
        }
        // Equivalent to `offset + length > width`, but written so that the
        // uint32_t sum cannot wrap around and let an out-of-range request pass.
        if (length > mBufferDesc.width || offset > mBufferDesc.width - length) {
            LOG(ERROR) << "Request size larger than the memory size.";
            return false;
        }
        return true;
    }

   private:
    // Buffer passed to set(); nullptr until set() has been called.
    const AHardwareBuffer* mHardwareBuffer = nullptr;
    // Description filled in by set(). Value-initialized so it is never read
    // uninitialized.
    AHardwareBuffer_Desc mBufferDesc{};
};

// A utility class to accumulate multiple Memory objects and assign each
// a distinct index number, starting with 0.
//
// The user of this class is responsible for avoiding concurrent calls
// to this class from multiple threads.
class MemoryTracker {
   private:
    // The vector of Memory pointers we are building.
    std::vector<const Memory*> mMemories;
    // A faster way to see if we already have a memory than doing find().
    std::unordered_map<const Memory*, uint32_t> mKnown;

   public:
    // Adds the memory, if it does not already exists.  Returns its index.
    // The memories should survive the tracker.
    uint32_t add(const Memory* memory);
    // Returns the number of memories contained.
    uint32_t size() const { return static_cast<uint32_t>(mKnown.size()); }
    // Returns the ith memory.
    const Memory* operator[](size_t i) const { return mMemories[i]; }
    // Iteration
    decltype(mMemories.begin()) begin() { return mMemories.begin(); }
    decltype(mMemories.end()) end() { return mMemories.end(); }
};

}  // namespace nn
}  // namespace android

#endif  // ANDROID_ML_NN_RUNTIME_MEMORY_H