/*
 * Copyright 2010, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "slang_rs_export_func.h"

#include <string>

#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"

#include "slang_assert.h"
#include "slang_rs_context.h"

namespace slang {

namespace {

// Checks whether FD is acceptable as an exported (invokable) function.
//
// The only requirement enforced here is that FD's canonical return type is
// exactly `void`. When it is not, an error is reported at FD's location via
// Context->ReportError.
//
// Returns true if FD may be exported, false otherwise.
bool ValidateFuncDecl(slang::RSContext *Context,
                      const clang::FunctionDecl *FD) {
  slangAssert(Context && FD);

  const clang::QualType RetTy = FD->getReturnType().getCanonicalType();
  const bool ReturnsVoid = (RetTy == FD->getASTContext().VoidTy);

  if (!ReturnsVoid) {
    Context->ReportError(
        FD->getLocation(),
        "invokable non-static functions are required to return void");
  }
  return ReturnsVoid;
}

}  // namespace

// Creates the export descriptor for the invokable function FD.
//
// FD must have a non-empty name (asserted) and must pass ValidateFuncDecl.
// Otherwise nullptr is returned after ValidateFuncDecl has reported an error.
//
// When FD has no parameters, mParamPacketType is set to nullptr. Otherwise a
// synthetic struct is declared at translation-unit scope, with one field per
// parameter (same identifier, original type, same order). Its RSExportType is
// stored in mParamPacketType.
//
// If RSExportType cannot export that struct (at least one parameter type is
// not supported by reflection), a message is printed to stderr and nullptr is
// returned.
RSExportFunc *RSExportFunc::Create(RSContext *Context,
                                   const clang::FunctionDecl *FD) {
  llvm::StringRef Name = FD->getName();

  slangAssert(!Name.empty() && "Function must have a name");

  if (!ValidateFuncDecl(Context, FD)) {
    return nullptr;
  }

  // NOTE(review): F is not deleted on the failure path further below.
  // Presumably the RSExportable base registers it with Context, which then
  // owns it. Confirm before adding a `delete`, which could double-free.
  RSExportFunc *F = new RSExportFunc(Context, Name, FD);

  // A function without parameters has no parameter packet.
  if (FD->getNumParams() == 0) {
    F->mParamPacketType = nullptr;
    return F;
  }

  clang::ASTContext &Ctx = Context->getASTContext();

  // Synthesize a struct with one field per parameter so the parameter list
  // can be reflected as a single record type.
  std::string Id = CreateDummyName("helper_func_param", F->getName());

  clang::RecordDecl *RD =
      clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
                                Ctx.getTranslationUnitDecl(),
                                clang::SourceLocation(),
                                clang::SourceLocation(),
                                &Ctx.Idents.get(Id));

  for (unsigned i = 0, e = FD->getNumParams(); i < e; ++i) {
    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);

    // Default values cannot be carried in the packet. Emit a note and
    // continue with the parameter's type only.
    if (PVD->hasDefaultArg())
      fprintf(stderr, "Note: parameter '%s' in function '%s' has default "
                      "value which is not supported\n",
                      PVD->getName().str().c_str(),
                      F->getName().c_str());

    // Use a distinct name so it does not shadow the FunctionDecl parameter FD.
    clang::FieldDecl *FieldD =
        clang::FieldDecl::Create(Ctx,
                                 RD,
                                 clang::SourceLocation(),
                                 clang::SourceLocation(),
                                 PVD->getIdentifier(),
                                 PVD->getOriginalType(),
                                 /* TInfo = */ nullptr,
                                 /* BitWidth = */ nullptr,
                                 /* Mutable = */ false,
                                 /* HasInit = */ clang::ICIS_NoInit);
    RD->addDecl(FieldD);
  }

  RD->completeDefinition();

  clang::QualType T = Ctx.getTagDeclType(RD);
  slangAssert(!T.isNull());

  RSExportType *ET =
      RSExportType::Create(Context, T.getTypePtr(), NotLegacyKernelArgument);

  if (ET == nullptr) {
    fprintf(stderr, "Failed to export the function %s. There's at least one "
                    "parameter whose type is not supported by the "
                    "reflection\n", F->getName().c_str());
    return nullptr;
  }

  slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
              "Parameter packet must be a record");

  F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
  return F;
}

bool
RSExportFunc::checkParameterPacketType(llvm::StructType *ParamTy) const {
  // A parameterless function matches only a missing packet type. A function
  // with parameters never matches a missing one.
  if (!hasParam())
    return ParamTy == nullptr;
  if (ParamTy == nullptr)
    return false;

  slangAssert(mParamPacketType != nullptr);

  const RSExportRecordType *Packet = mParamPacketType;

  // The reflected record and the LLVM struct must have equally many members.
  if (Packet->getFields().size() != ParamTy->getNumElements())
    return false;

  const auto *DL = getRSContext()->getDataLayout();
  const llvm::StructLayout *Layout = DL->getStructLayout(ParamTy);

  unsigned Idx = 0;
  for (RSExportRecordType::const_field_iterator It = Packet->fields_begin(),
       End = Packet->fields_end(); It != End; ++It, ++Idx) {
    const RSExportRecordType::Field *Fld = *It;

    llvm::Type *Expected = Fld->getType()->getLLVMType();
    llvm::Type *Actual = ParamTy->getTypeAtIndex(Idx);

    // Identical LLVM types need no layout comparison.
    if (Expected == Actual)
      continue;

    // Differing types are tolerated only if they occupy the same offset ...
    const size_t ExpectedOffset = Fld->getOffsetInParent();
    const size_t ActualOffset = Layout->getElementOffset(Idx);
    if (ExpectedOffset != ActualOffset)
      return false;

    // ... and the same allocation size.
    const size_t ExpectedSize = Fld->getType()->getAllocSize();
    const size_t ActualSize = DL->getTypeAllocSize(Actual);
    if (ExpectedSize != ActualSize)
      return false;
  }

  return true;
}

}  // namespace slang