// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
// http://code.google.com/p/ceres-solver/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Author: sameeragarwal@google.com (Sameer Agarwal)

#include "ceres/linear_least_squares_problems.h"

#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include "ceres/block_sparse_matrix.h"
#include "ceres/block_structure.h"
#include "ceres/casts.h"
#include "ceres/compressed_row_sparse_matrix.h"
#include "ceres/file.h"
#include "ceres/internal/scoped_ptr.h"
#include "ceres/matrix_proto.h"
#include "ceres/stringprintf.h"
#include "ceres/triplet_sparse_matrix.h"
#include "ceres/types.h"
#include "glog/logging.h"

namespace ceres {
namespace internal {

// Returns the canned problem built by LinearLeastSquaresProblem<id>() for
// id in {0, 1, 2, 3}. Any other id is reported with LOG(FATAL); the trailing
// NULL is what the function yields should control get past that.
LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromId(int id) {
  if (id == 0) {
    return LinearLeastSquaresProblem0();
  }
  if (id == 1) {
    return LinearLeastSquaresProblem1();
  }
  if (id == 2) {
    return LinearLeastSquaresProblem2();
  }
  if (id == 3) {
    return LinearLeastSquaresProblem3();
  }

  LOG(FATAL) << "Unknown problem id requested " << id;
  return NULL;
}

#ifndef CERES_NO_PROTOCOL_BUFFERS
// Reads the serialized LinearLeastSquaresProblemProto stored in filename and
// converts it into a freshly allocated LinearLeastSquaresProblem.
//
//  - A becomes a BlockSparseMatrix or a TripletSparseMatrix when the proto
//    carries that representation, and a CompressedRowSparseMatrix otherwise.
//  - b and D are copied when the corresponding repeated field is non-empty.
//  - The proto's x entries are stored in x_D when the proto has a non-empty
//    d, and in x otherwise.
//  - num_eliminate_blocks is taken from the proto when set, else it is 0.
//
// A proto that fails to parse trips a CHECK.
LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile(
    const string& filename) {
  LinearLeastSquaresProblemProto proto;
  {
    // Scoped so that the raw file contents are released right after parsing.
    string file_contents;
    ReadFileToStringOrDie(filename, &file_contents);
    CHECK(proto.ParseFromString(file_contents));
  }

  LinearLeastSquaresProblem* result = new LinearLeastSquaresProblem;

  // Choose the matrix class that matches the representation in the proto.
  const SparseMatrixProto& matrix_proto = proto.a();
  if (matrix_proto.has_block_matrix()) {
    result->A.reset(new BlockSparseMatrix(matrix_proto));
  } else if (matrix_proto.has_triplet_matrix()) {
    result->A.reset(new TripletSparseMatrix(matrix_proto));
  } else {
    result->A.reset(new CompressedRowSparseMatrix(matrix_proto));
  }

  const int b_size = proto.b_size();
  if (b_size > 0) {
    double* rhs = new double[b_size];
    for (int k = 0; k < b_size; ++k) {
      rhs[k] = proto.b(k);
    }
    result->b.reset(rhs);
  }

  const int d_size = proto.d_size();
  if (d_size > 0) {
    double* diagonal = new double[d_size];
    for (int k = 0; k < d_size; ++k) {
      diagonal[k] = proto.d(k);
    }
    result->D.reset(diagonal);
  }

  const int x_size = proto.x_size();
  if (x_size > 0) {
    double* solution = new double[x_size];
    for (int k = 0; k < x_size; ++k) {
      solution[k] = proto.x(k);
    }
    // With a diagonal present the stored vector goes to x_D, otherwise to x.
    if (d_size > 0) {
      result->x_D.reset(solution);
    } else {
      result->x.reset(solution);
    }
  }

  result->num_eliminate_blocks =
      proto.has_num_eliminate_blocks() ? proto.num_eliminate_blocks() : 0;

  return result;
}
#else
// Stand-in used when Ceres is built without protocol buffers: logs a fatal
// message and, should control continue, returns NULL.
LinearLeastSquaresProblem* CreateLinearLeastSquaresProblemFromFile(
    const string& filename) {
  LOG(FATAL) << "Loading a least squares problem from disk requires "
                "Ceres to be built with Protocol Buffers support.";
  return NULL;
}
#endif  // CERES_NO_PROTOCOL_BUFFERS

/*
A = [1   2]
    [3   4]
    [6 -10]

b = [  8
      18
     -18]

x = [2
     3]

D = [1
     2]

x_D = [1.78448275;
       2.82327586;]
 */
// Builds the 3x2 problem listed in the comment above: A is stored as a
// TripletSparseMatrix with all six entries present, and b, D, x and x_D are
// filled with the values shown there.
LinearLeastSquaresProblem* LinearLeastSquaresProblem0() {
  const int kNumRows = 3;
  const int kNumCols = 2;
  const int kNumEntries = kNumRows * kNumCols;
  // Entries of A in row-major order.
  static const double kEntries[] = {1., 2., 3., 4., 6, -10};

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;

  TripletSparseMatrix* A =
      new TripletSparseMatrix(kNumRows, kNumCols, kNumEntries);
  problem->b.reset(new double[kNumRows]);
  problem->D.reset(new double[kNumCols]);

  problem->x.reset(new double[kNumCols]);
  problem->x_D.reset(new double[kNumCols]);

  int* row_indices = A->mutable_rows();
  int* col_indices = A->mutable_cols();
  double* entries = A->mutable_values();

  // Entry k of the row-major table sits at (k / kNumCols, k % kNumCols).
  for (int k = 0; k < kNumEntries; ++k) {
    row_indices[k] = k / kNumCols;
    col_indices[k] = k % kNumCols;
    entries[k] = kEntries[k];
  }
  A->set_num_nonzeros(kNumEntries);
  problem->A.reset(A);

  double* b = problem->b.get();
  b[0] = 8;
  b[1] = 18;
  b[2] = -18;

  double* x = problem->x.get();
  x[0] = 2.0;
  x[1] = 3.0;

  double* D = problem->D.get();
  D[0] = 1;
  D[1] = 2;

  double* x_D = problem->x_D.get();
  x_D[0] = 1.78448275;
  x_D[1] = 2.82327586;

  return problem;
}


/*
      A = [1 0  | 2 0 0
           3 0  | 0 4 0
           0 5  | 0 0 6
           0 7  | 8 0 0
           0 9  | 1 0 0
           0 0  | 1 1 1]

      b = [0
           1
           2
           3
           4
           5]

      c = A'* b = [ 3
                   67
                   33
                    9
                   17]

      A'A = [10    0    2   12   0
              0  155   65    0  30
              2   65   70    1   1
             12    0    1   17   1
              0   30    1    1  37]

      S = [ 42.3419  -1.4000  -11.5806
            -1.4000   2.6000    1.0000
            11.5806   1.0000   31.1935]

      r = [ 4.3032
            5.4000
            5.0323]

      S\r = [ 0.2102
              2.1367
              0.1388]

      A\b = [-2.3061
              0.3172
              0.2102
              2.1367
              0.1388]
*/
// The following two functions create a TripletSparseMatrix and a
// BlockSparseMatrix version of this problem.

// TripletSparseMatrix version.
// Builds the 6x5 problem shown above with A stored as a TripletSparseMatrix.
// D is all ones, b[i] == i, and num_eliminate_blocks is 2.
LinearLeastSquaresProblem* LinearLeastSquaresProblem1() {
  const int kNumRows = 6;
  const int kNumCols = 5;

  // (row, column, value) of every stored entry of A, listed row by row.
  static const int kEntryRows[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5};
  static const int kEntryCols[] = {0, 2, 0, 3, 1, 4, 1, 2, 1, 2, 2, 3, 4};
  static const double kEntryValues[] =
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1};
  const int kNumEntries =
      static_cast<int>(sizeof(kEntryValues) / sizeof(kEntryValues[0]));

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;
  // Storage is reserved for a fully dense matrix although only kNumEntries
  // triplets are filled in.
  TripletSparseMatrix* A =
      new TripletSparseMatrix(kNumRows, kNumCols, kNumRows * kNumCols);
  problem->b.reset(new double[kNumRows]);
  problem->D.reset(new double[kNumCols]);
  problem->num_eliminate_blocks = 2;

  int* row_indices = A->mutable_rows();
  int* col_indices = A->mutable_cols();
  double* entries = A->mutable_values();
  for (int k = 0; k < kNumEntries; ++k) {
    row_indices[k] = kEntryRows[k];
    col_indices[k] = kEntryCols[k];
    entries[k] = kEntryValues[k];
  }

  A->set_num_nonzeros(kNumEntries);
  CHECK(A->IsValid());

  problem->A.reset(A);

  double* D = problem->D.get();
  for (int c = 0; c < kNumCols; ++c) {
    D[c] = 1;
  }

  double* b = problem->b.get();
  for (int r = 0; r < kNumRows; ++r) {
    b[r] = r;
  }

  return problem;
}

// BlockSparseMatrix version
// Builds the same 6x5 problem as LinearLeastSquaresProblem1, but with A
// stored as a BlockSparseMatrix in which every row block and column block
// has size 1. D is all ones, b[i] == i, and num_eliminate_blocks is 2.
LinearLeastSquaresProblem* LinearLeastSquaresProblem2() {
  const int kNumRows = 6;
  const int kNumCols = 5;

  // Number of cells in each row, followed by the column block and value of
  // every cell in row-major order. A cell's position in the values array is
  // simply its index in these flat tables.
  static const int kCellsPerRow[] = {2, 2, 2, 2, 2, 3};
  static const int kCellColumnBlocks[] =
      {0, 2, 0, 3, 1, 4, 1, 2, 1, 2, 2, 3, 4};
  static const double kCellValues[] =
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 1, 1};
  const int kNumCells =
      static_cast<int>(sizeof(kCellValues) / sizeof(kCellValues[0]));

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;

  problem->b.reset(new double[kNumRows]);
  problem->D.reset(new double[kNumCols]);
  problem->num_eliminate_blocks = 2;

  CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;

  for (int c = 0; c < kNumCols; ++c) {
    Block column;
    column.size = 1;
    column.position = c;
    bs->cols.push_back(column);
  }

  int cell_index = 0;
  for (int r = 0; r < kNumRows; ++r) {
    bs->rows.push_back(CompressedRow());
    CompressedRow& row = bs->rows.back();
    row.block.size = 1;
    row.block.position = r;
    for (int k = 0; k < kCellsPerRow[r]; ++k, ++cell_index) {
      row.cells.push_back(Cell(kCellColumnBlocks[cell_index], cell_index));
    }
  }

  BlockSparseMatrix* A = new BlockSparseMatrix(bs);
  memcpy(A->mutable_values(), kCellValues, kNumCells * sizeof(*A->values()));

  double* D = problem->D.get();
  for (int c = 0; c < kNumCols; ++c) {
    D[c] = 1;
  }

  double* b = problem->b.get();
  for (int r = 0; r < kNumRows; ++r) {
    b[r] = r;
  }

  problem->A.reset(A);

  return problem;
}


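/*
      A = [1 0
           3 0
           0 5
           0 7
           0 9
           0 0]

      b = [0
           1
           2
           3
           4
           5]

      Note: the implementation below uses num_rows = 5, so only the first
      five rows of A and the first five entries of b are actually
      constructed; the trailing all-zero row of A and the last entry of b
      (5) are not part of the generated problem.
*/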
// BlockSparseMatrix version
// Builds a 5x2 problem whose A is a BlockSparseMatrix with one size-1 cell
// per row: rows 0-1 have their entry in column 0 and rows 2-4 in column 1.
// D is all ones, b[i] == i, and num_eliminate_blocks is 2.
LinearLeastSquaresProblem* LinearLeastSquaresProblem3() {
  const int kNumRows = 5;
  const int kNumCols = 2;

  // Column block and value of the single cell stored in each row.
  static const int kCellColumnBlocks[] = {0, 0, 1, 1, 1};
  static const double kCellValues[] = {1, 3, 5, 7, 9};

  LinearLeastSquaresProblem* problem = new LinearLeastSquaresProblem;

  problem->b.reset(new double[kNumRows]);
  problem->D.reset(new double[kNumCols]);
  problem->num_eliminate_blocks = 2;

  CompressedRowBlockStructure* bs = new CompressedRowBlockStructure;

  for (int c = 0; c < kNumCols; ++c) {
    Block column;
    column.size = 1;
    column.position = c;
    bs->cols.push_back(column);
  }

  // With exactly one cell per row, the cell's offset into the values array
  // equals the row index.
  for (int r = 0; r < kNumRows; ++r) {
    bs->rows.push_back(CompressedRow());
    CompressedRow& row = bs->rows.back();
    row.block.size = 1;
    row.block.position = r;
    row.cells.push_back(Cell(kCellColumnBlocks[r], r));
  }

  BlockSparseMatrix* A = new BlockSparseMatrix(bs);
  memcpy(A->mutable_values(), kCellValues, kNumRows * sizeof(*A->values()));

  double* D = problem->D.get();
  for (int c = 0; c < kNumCols; ++c) {
    D[c] = 1;
  }

  double* b = problem->b.get();
  for (int r = 0; r < kNumRows; ++r) {
    b[r] = r;
  }

  problem->A.reset(A);

  return problem;
}

bool DumpLinearLeastSquaresProblemToConsole(const string& directory,
                                            int iteration,
                                            const SparseMatrix* A,
                                            const double* D,
                                            const double* b,
                                            const double* x,
                                            int num_eliminate_blocks) {
  CHECK_NOTNULL(A);
  Matrix AA;
  A->ToDenseMatrix(&AA);
  LOG(INFO) << "A^T: \n" << AA.transpose();

  if (D != NULL) {
    LOG(INFO) << "A's appended diagonal:\n"
              << ConstVectorRef(D, A->num_cols());
  }

  if (b != NULL) {
    LOG(INFO) << "b: \n" << ConstVectorRef(b, A->num_rows());
  }

  if (x != NULL) {
    LOG(INFO) << "x: \n" << ConstVectorRef(x, A->num_cols());
  }
  return true;
};

#ifndef CERES_NO_PROTOCOL_BUFFERS
// Serializes the problem into a LinearLeastSquaresProblemProto and writes it
// to the file JoinPath(directory, "lm_iteration_NNN.lsqp"), where NNN is
// iteration formatted with "%03d".
//
// A is converted with A->ToProto(). D and x (A->num_cols() entries each) and
// b (A->num_rows() entries) are added only when non-NULL.
// num_eliminate_blocks is always stored. Always returns true; A must not be
// NULL.
bool DumpLinearLeastSquaresProblemToProtocolBuffer(const string& directory,
                                                   int iteration,
                                                   const SparseMatrix* A,
                                                   const double* D,
                                                   const double* b,
                                                   const double* x,
                                                   int num_eliminate_blocks) {
  CHECK_NOTNULL(A);
  LinearLeastSquaresProblemProto lsqp;
  A->ToProto(lsqp.mutable_a());

  if (D != NULL) {
    for (int i = 0; i < A->num_cols(); ++i) {
      lsqp.add_d(D[i]);
    }
  }

  if (b != NULL) {
    for (int i = 0; i < A->num_rows(); ++i) {
      lsqp.add_b(b[i]);
    }
  }

  if (x != NULL) {
    for (int i = 0; i < A->num_cols(); ++i) {
      lsqp.add_x(x[i]);
    }
  }

  lsqp.set_num_eliminate_blocks(num_eliminate_blocks);

  // Format the basename first and only then join it with the directory.
  // The caller-supplied directory must never be part of a printf format
  // string: a '%' in it would be interpreted as a conversion specifier.
  string filename =
      JoinPath(directory,
               StringPrintf("lm_iteration_%03d.lsqp", iteration));
  LOG(INFO) << "Dumping least squares problem for iteration " << iteration
            << " to disk. File: " << filename;
  WriteStringToFileOrDie(lsqp.SerializeAsString(), filename);
  return true;
}
#else
// Stand-in used when Ceres is built without protocol buffers: logs an error
// and returns false without writing anything.
bool DumpLinearLeastSquaresProblemToProtocolBuffer(const string& directory,
                                                   int iteration,
                                                   const SparseMatrix* A,
                                                   const double* D,
                                                   const double* b,
                                                   const double* x,
                                                   int num_eliminate_blocks) {
  LOG(ERROR) << "Dumping least squares problems is only "
             << "supported when Ceres is compiled with "
             << "protocol buffer support.";
  return false;
}
#endif

void WriteArrayToFileOrDie(const string& filename,
                           const double* x,
                           const int size) {
  CHECK_NOTNULL(x);
  VLOG(2) << "Writing array to: " << filename;
  FILE* fptr = fopen(filename.c_str(), "w");
  CHECK_NOTNULL(fptr);
  for (int i = 0; i < size; ++i) {
    fprintf(fptr, "%17f\n", x[i]);
  }
  fclose(fptr);
}

// Dumps the problem as a set of text files plus a MATLAB script that loads
// them. All file names share the prefix
// JoinPath(directory, "lm_iteration_NNN"), where NNN is iteration formatted
// with "%03d":
//
//   <prefix>_A.txt  written by A->ToTextFile().
//   <prefix>_D.txt  A->num_cols() values, only when D is non-NULL.
//   <prefix>_b.txt  A->num_rows() values, only when b is non-NULL.
//   <prefix>_x.txt  A->num_cols() values, only when x is non-NULL.
//   <prefix>.m      MATLAB function lm_iteration_NNN() returning a struct
//                   lsqp with num_rows, num_cols, A and whichever of D, b, x
//                   were written.
//
// num_eliminate_blocks is not used here. Always returns true; A must not be
// NULL and a failure to open the A file trips a CHECK.
bool DumpLinearLeastSquaresProblemToTextFile(const string& directory,
                                             int iteration,
                                             const SparseMatrix* A,
                                             const double* D,
                                             const double* b,
                                             const double* x,
                                             int num_eliminate_blocks) {
  CHECK_NOTNULL(A);
  // Format the basename first and only then join it with the directory.
  // The caller-supplied directory must never be part of a printf format
  // string: a '%' in it would be interpreted as a conversion specifier.
  string filename_prefix =
      JoinPath(directory, StringPrintf("lm_iteration_%03d", iteration));

  LOG(INFO) << "writing to: " << filename_prefix << "*";

  string matlab_script;
  StringAppendF(&matlab_script,
                "function lsqp = lm_iteration_%03d()\n", iteration);
  StringAppendF(&matlab_script,
                "lsqp.num_rows = %d;\n", A->num_rows());
  StringAppendF(&matlab_script,
                "lsqp.num_cols = %d;\n", A->num_cols());

  {
    string filename = filename_prefix + "_A.txt";
    FILE* fptr = fopen(filename.c_str(), "w");
    CHECK_NOTNULL(fptr);
    A->ToTextFile(fptr);
    fclose(fptr);
    StringAppendF(&matlab_script,
                  "tmp = load('%s', '-ascii');\n", filename.c_str());
    // The script adds 1 to the first two columns of the loaded table before
    // handing them to sparse() as row/column indices (presumably because
    // ToTextFile writes zero-based indices; confirm against ToTextFile).
    StringAppendF(
        &matlab_script,
        "lsqp.A = sparse(tmp(:, 1) + 1, tmp(:, 2) + 1, tmp(:, 3), %d, %d);\n",
        A->num_rows(),
        A->num_cols());
  }


  if (D != NULL) {
    string filename = filename_prefix + "_D.txt";
    WriteArrayToFileOrDie(filename, D, A->num_cols());
    StringAppendF(&matlab_script,
                  "lsqp.D = load('%s', '-ascii');\n", filename.c_str());
  }

  if (b != NULL) {
    string filename = filename_prefix + "_b.txt";
    WriteArrayToFileOrDie(filename, b, A->num_rows());
    StringAppendF(&matlab_script,
                  "lsqp.b = load('%s', '-ascii');\n", filename.c_str());
  }

  if (x != NULL) {
    string filename = filename_prefix + "_x.txt";
    WriteArrayToFileOrDie(filename, x, A->num_cols());
    StringAppendF(&matlab_script,
                  "lsqp.x = load('%s', '-ascii');\n", filename.c_str());
  }

  string matlab_filename = filename_prefix + ".m";
  WriteStringToFileOrDie(matlab_script, matlab_filename);
  return true;
}

// Forwards the problem to the dumper selected by dump_format_type
// (CONSOLE, PROTOBUF or TEXTFILE) and returns that dumper's result. Any
// other value is reported with LOG(FATAL); the trailing return yields true
// should control get past that.
bool DumpLinearLeastSquaresProblem(const string& directory,
                                   int iteration,
                                   DumpFormatType dump_format_type,
                                   const SparseMatrix* A,
                                   const double* D,
                                   const double* b,
                                   const double* x,
                                   int num_eliminate_blocks) {
  if (dump_format_type == CONSOLE) {
    return DumpLinearLeastSquaresProblemToConsole(
        directory, iteration, A, D, b, x, num_eliminate_blocks);
  }

  if (dump_format_type == PROTOBUF) {
    return DumpLinearLeastSquaresProblemToProtocolBuffer(
        directory, iteration, A, D, b, x, num_eliminate_blocks);
  }

  if (dump_format_type == TEXTFILE) {
    return DumpLinearLeastSquaresProblemToTextFile(
        directory, iteration, A, D, b, x, num_eliminate_blocks);
  }

  LOG(FATAL) << "Unknown DumpFormatType " << dump_format_type;
  return true;
}

}  // namespace internal
}  // namespace ceres