/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include <google/protobuf/descriptor.h>
#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/stubs/strutil.h>

#include "nugget/protobuf/options.pb.h"

using ::google::protobuf::FileDescriptor;
using ::google::protobuf::JoinStrings;
using ::google::protobuf::MethodDescriptor;
using ::google::protobuf::ServiceDescriptor;
using ::google::protobuf::Split;
using ::google::protobuf::SplitStringUsing;
using ::google::protobuf::StripSuffixString;
using ::google::protobuf::compiler::CodeGenerator;
using ::google::protobuf::compiler::OutputDirectory;
using ::google::protobuf::io::Printer;
using ::google::protobuf::io::ZeroCopyOutputStream;

using ::nugget::protobuf::app_id;
using ::nugget::protobuf::request_buffer_size;
using ::nugget::protobuf::response_buffer_size;

namespace {

// Checks that |service| sets every nugget.protobuf service option the
// generators in this file read (app_id, request_buffer_size,
// response_buffer_size).
//
// Returns an empty string when all of them are present. Otherwise it returns a
// message naming the first missing option, checked in the order listed above.
std::string validateServiceOptions(const ServiceDescriptor& service) {
    const auto& opts = service.options();
    std::string problem;
    if (!opts.HasExtension(app_id)) {
        problem = "nugget.protobuf.app_id is not defined for service " + service.name();
    } else if (!opts.HasExtension(request_buffer_size)) {
        problem = "nugget.protobuf.request_buffer_size is not defined for service " + service.name();
    } else if (!opts.HasExtension(response_buffer_size)) {
        problem = "nugget.protobuf.response_buffer_size is not defined for service " + service.name();
    }
    return problem;
}

// Returns the dot-separated components of |descriptor|'s full name with the
// last component (the descriptor's own name) removed.
// Example: "a.b.Foo" -> {"a", "b"}; "Foo" -> {}.
template <typename Descriptor>
std::vector<std::string> Packages(const Descriptor& descriptor) {
    std::vector<std::string> namespaces;
    SplitStringUsing(descriptor.full_name(), ".", &namespaces);
    // Drop the descriptor's own name so only the enclosing scopes remain.
    // Guarded because pop_back() on an empty vector is undefined behavior.
    if (!namespaces.empty()) {
        namespaces.pop_back();
    }
    return namespaces;
}

// Builds the globally qualified C++ name for |descriptor|. Every enclosing
// scope from Packages() is prefixed with "::", then "::<name>" is appended.
// Example: full name "a.b.Foo" -> "::a::b::Foo"; "Foo" -> "::Foo".
template <typename Descriptor>
std::string FullyQualifiedIdentifier(const Descriptor& descriptor) {
    std::string qualified;
    for (const auto& scope : Packages(descriptor)) {
        qualified += "::";
        qualified += scope;
    }
    qualified += "::";
    qualified += descriptor.name();
    return qualified;
}

// Computes the include path of the protoc-generated C++ header for the file
// that declares |descriptor|. The path is formed from the descriptor's package
// components joined with '/', followed by the .proto basename with ".proto"
// replaced by ".pb.h".
// Example: package "a.b", file "x/y/foo.proto" -> "a/b/foo.pb.h".
template <typename Descriptor>
std::string FullyQualifiedHeader(const Descriptor& descriptor) {
    const auto proto_basename = Split(descriptor.file()->name(), "/").back();
    std::string include_path;
    for (const auto& package : Packages(descriptor)) {
        include_path += package;
        include_path += '/';
    }
    include_path += StripSuffixString(proto_basename, ".proto") + ".pb.h";
    return include_path;
}

// Prints one "namespace <name> {" line for each scope returned by Packages(),
// outermost first. Each line is preceded by a newline.
template <typename Descriptor>
void OpenNamespaces(Printer& printer, const Descriptor& descriptor) {
    std::map<std::string, std::string> scope_vars;
    for (const std::string& scope : Packages(descriptor)) {
        scope_vars["namespace"] = scope;
        printer.Print(scope_vars, R"(
namespace $namespace$ {)");
    }
}

// Mirrors OpenNamespaces(): prints "} // namespace <name>" for each scope
// returned by Packages(), innermost first. Each line is preceded by a newline.
template <typename Descriptor>
void CloseNamespaces(Printer& printer, const Descriptor& descriptor) {
    const std::vector<std::string> scopes = Packages(descriptor);
    std::map<std::string, std::string> scope_vars;
    for (auto depth = scopes.size(); depth > 0; --depth) {
        scope_vars["namespace"] = scopes[depth - 1];
        printer.Print(scope_vars, R"(
} // namespace $namespace$)");
    }
}

void ForEachMethod(const ServiceDescriptor& service,
                   std::function<void(std::map<std::string, std::string>)> handler) {
    for (int i = 0; i < service.method_count(); ++i) {
        const MethodDescriptor& method = *service.method(i);
        std::map<std::string, std::string> vars;
        vars["method_id"] = std::to_string(i);
        vars["method_name"] = method.name();
        vars["method_input_type"] = FullyQualifiedIdentifier(*method.input_type());
        vars["method_output_type"] = FullyQualifiedIdentifier(*method.output_type());
        handler(vars);
    }
}

// Emits the contents of Mock<Service>.client.h. The file has an include guard,
// includes <gmock/gmock.h> and <Service>.client.h, and declares a
// Mock<Service> struct deriving from I<Service> with one MOCK_METHOD2 per RPC
// method, all inside the service's package namespaces.
void GenerateMockClient(Printer& printer, const ServiceDescriptor& service) {
    const std::string& service_name = service.name();
    const std::map<std::string, std::string> vars{
        {"include_guard", "PROTOC_GENERATED_MOCK_" + service_name + "_CLIENT_H"},
        {"service_header", service_name + ".client.h"},
        {"mock_class", "Mock" + service_name},
        {"class", service_name},
    };

    // Include guard and includes.
    printer.Print(vars, R"(
#ifndef $include_guard$
#define $include_guard$

#include <gmock/gmock.h>

#include <$service_header$>)");

    OpenNamespaces(printer, service);

    printer.Print(vars, R"(
struct $mock_class$ : public I$class$ {)");

    // One gmock declaration per RPC method.
    ForEachMethod(service, [&printer](std::map<std::string, std::string> method_vars) {
        printer.Print(method_vars, R"(
    MOCK_METHOD2($method_name$, uint32_t(const $method_input_type$&, $method_output_type$*));)");
    });

    printer.Print(vars, R"(
};)");

    CloseNamespaces(printer, service);

    printer.Print(vars, R"(
#endif)");
}

// Emits the contents of <Service>.client.h. Inside the service's package
// namespaces it declares:
//   * I<Service>: a pure-virtual interface with one method per RPC, so code
//     can be tested against a mock; and
//   * <Service>: an implementation holding a ::nos::AppClient constructed from
//     a ::nos::NuggetClientInterface& and APP_ID_<app_id option>.
void GenerateClientHeader(Printer& printer, const ServiceDescriptor& service) {
    const std::string& service_name = service.name();
    const std::map<std::string, std::string> vars{
        {"include_guard", "PROTOC_GENERATED_" + service_name + "_CLIENT_H"},
        {"protobuf_header", FullyQualifiedHeader(service)},
        {"class", service_name},
        {"iface_class", "I" + service_name},
        {"app_id", "APP_ID_" + service.options().GetExtension(app_id)},
    };

    // Include guard and includes.
    printer.Print(vars, R"(
#ifndef $include_guard$
#define $include_guard$

#include <application.h>
#include <nos/AppClient.h>
#include <nos/NuggetClientInterface.h>

#include "$protobuf_header$")");

    OpenNamespaces(printer, service);

    // Interface: every RPC becomes a pure virtual method.
    printer.Print(vars, R"(
class $iface_class$ {
public:
    virtual ~$iface_class$() = default;)");

    ForEachMethod(service, [&printer](std::map<std::string, std::string> method_vars) {
        printer.Print(method_vars, R"(
    virtual uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) = 0;)");
    });

    printer.Print(vars, R"(
};)");

    // Concrete client: method bodies are emitted by GenerateClientSource().
    printer.Print(vars, R"(
class $class$ : public $iface_class$ {
    ::nos::AppClient _app;
public:
    $class$(::nos::NuggetClientInterface& client) : _app{client, $app_id$} {}
    ~$class$() override = default;)");

    ForEachMethod(service, [&printer](std::map<std::string, std::string> method_vars) {
        printer.Print(method_vars, R"(
    uint32_t $method_name$(const $method_input_type$&, $method_output_type$*) override;)");
    });

    printer.Print(vars, R"(
};)");

    CloseNamespaces(printer, service);

    printer.Print(vars, R"(
#endif)");
}

// Emits the contents of <Service>.client.cpp: one out-of-line definition per
// RPC method of the <Service> class declared in <Service>.client.h. Each
// generated method does the following:
//   1. Returns APP_ERROR_TOO_MUCH if the serialized request is larger than the
//      request_buffer_size option.
//   2. Serializes the request and calls _app.Call() with the method's index.
//      A response buffer of response_buffer_size bytes is passed only when
//      |response| is non-null.
//   3. On APP_SUCCESS, parses that buffer into |response|.
void GenerateClientSource(Printer& printer, const ServiceDescriptor& service) {
    const auto& options = service.options();
    const uint32_t request_limit = options.GetExtension(request_buffer_size);
    const uint32_t response_limit = options.GetExtension(response_buffer_size);

    const std::map<std::string, std::string> vars{
        {"generated_header", service.name() + ".client.h"},
        {"class", service.name()},
        {"max_request_size", std::to_string(request_limit)},
        {"max_response_size", std::to_string(response_limit)},
    };

    printer.Print(vars, R"(
#include <$generated_header$>

#include <application.h>)");

    OpenNamespaces(printer, service);

    // Method definitions.
    ForEachMethod(service, [&](std::map<std::string, std::string> method_vars) {
        // Add the service-level variables. insert() keeps any existing key.
        for (const auto& entry : vars) {
            method_vars.insert(entry);
        }
        printer.Print(method_vars, R"(
uint32_t $class$::$method_name$(const $method_input_type$& request, $method_output_type$* response) {
    const size_t request_size = request.ByteSize();
    if (request_size > $max_request_size$) {
        return APP_ERROR_TOO_MUCH;
    }
    std::vector<uint8_t> buffer(request_size);
    if (!request.SerializeToArray(buffer.data(), buffer.size())) {
        return APP_ERROR_RPC;
    }
    std::vector<uint8_t> responseBuffer;
    if (response != nullptr) {
      responseBuffer.resize($max_response_size$);
    }
    const uint32_t appStatus = _app.Call($method_id$, buffer,
                                         (response != nullptr) ? &responseBuffer : nullptr);
    if (appStatus == APP_SUCCESS && response != nullptr) {
        if (!response->ParseFromArray(responseBuffer.data(), responseBuffer.size())) {
            return APP_ERROR_RPC;
        }
    }
    return appStatus;
})");
    });

    CloseNamespaces(printer, service);
}

// Generator for C++ Nugget service client
class CppNuggetServiceClientGenerator : public CodeGenerator {
public:
    CppNuggetServiceClientGenerator() = default;
    ~CppNuggetServiceClientGenerator() override = default;

    bool Generate(const FileDescriptor* file,
                  const std::string& parameter,
                  OutputDirectory* output_directory,
                  std::string* error) const override {
        for (int i = 0; i < file->service_count(); ++i) {
            const auto& service = *file->service(i);

            *error = validateServiceOptions(service);
            if (!error->empty()) {
                return false;
            }

            if (parameter == "mock") {
                std::unique_ptr<ZeroCopyOutputStream> output{
                        output_directory->Open("Mock" + service.name() + ".client.h")};
                Printer printer(output.get(), '$');
                GenerateMockClient(printer, service);
            } else if (parameter == "header") {
                std::unique_ptr<ZeroCopyOutputStream> output{
                        output_directory->Open(service.name() + ".client.h")};
                Printer printer(output.get(), '$');
                GenerateClientHeader(printer, service);
            } else if (parameter == "source") {
                std::unique_ptr<ZeroCopyOutputStream> output{
                        output_directory->Open(service.name() + ".client.cpp")};
                Printer printer(output.get(), '$');
                GenerateClientSource(printer, service);
            } else {
                *error = "Illegal parameter: must be mock|header|source";
                return false;
            }
        }

        return true;
    }

private:
    GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(CppNuggetServiceClientGenerator);
};

} // namespace

int main(int argc, char* argv[]) {
    GOOGLE_PROTOBUF_VERIFY_VERSION;
    CppNuggetServiceClientGenerator generator;
    return google::protobuf::compiler::PluginMain(argc, argv, &generator);
}