// Copyright (c) 2010 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/common/sqlite_utils.h"

#include <list>

#include "base/file_path.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/stl_util-inl.h"
#include "base/string16.h"
#include "base/synchronization/lock.h"

// Common base for the built-in error handlers. It stores the most recent
// error code so that GetLastError() can report it; concrete handlers only
// implement HandleError() and record the code they receive in |error_|.
class VanillaSQLErrorHandler : public SQLErrorHandler {
 public:
  // Nothing has failed yet, so start out reporting SQLITE_OK.
  VanillaSQLErrorHandler()
      : error_(SQLITE_OK) {}

  // Returns the code recorded by the most recent HandleError() call, or
  // SQLITE_OK if no error has been handled.
  virtual int GetLastError() const { return error_; }

 protected:
  // Last error code handled; written by subclasses in HandleError().
  int error_;
};

// Error handler used in debug builds: records the error and then trips
// NOTREACHED() so SQLite failures are noticed during development.
class DebugSQLErrorHandler: public VanillaSQLErrorHandler {
 public:
  // Stores |error| for GetLastError(), reports it (with the connection
  // address) through NOTREACHED(), and returns the same code to the caller.
  virtual int HandleError(int error, sqlite3* db) {
    error_ = error;
    const void* db_address = static_cast<void*>(db);
    NOTREACHED() << "sqlite error " << error
                 << " db " << db_address;
    return error;
  }
};

// Error handler used in release builds: records the error without crashing.
class ReleaseSQLErrorHandler : public VanillaSQLErrorHandler {
 public:
  // Remembers |error| for GetLastError() and returns it unchanged. This used
  // to CHECK-fail, but that produced a large number of crashes.
  virtual int HandleError(int error, sqlite3* db) {
    error_ = error;
    return error_;
  }
};

// Factory that hands out the build-appropriate error handler and owns every
// handler it creates; all of them are deleted when the factory is destroyed.
// Registration of new handlers is serialized by |lock_|, so Make() may be
// called from multiple threads.
// NOTE(review): each Make() call allocates a new handler that is kept until
// the factory dies, so the list grows by one entry per reported error.
class DefaultSQLErrorHandlerFactory : public SQLErrorHandlerFactory {
 public:
  // Deletes every handler previously returned by Make().
  ~DefaultSQLErrorHandlerFactory() {
    STLDeleteContainerPointers(errors_.begin(), errors_.end());
  }

  // Creates a handler (DebugSQLErrorHandler when NDEBUG is not defined,
  // ReleaseSQLErrorHandler otherwise), records it for later deletion and
  // returns it. The factory retains ownership.
  virtual SQLErrorHandler* Make() {
    SQLErrorHandler* const handler = CreateHandler();
    AddHandler(handler);
    return handler;
  }

 private:
  typedef std::list<SQLErrorHandler*> ErrorList;

  // Allocates the handler flavour that matches the build configuration.
  static SQLErrorHandler* CreateHandler() {
#ifndef NDEBUG
    return new DebugSQLErrorHandler;
#else
    return new ReleaseSQLErrorHandler;
#endif  // NDEBUG
  }

  // Appends |handler| to the ownership list while holding |lock_|.
  void AddHandler(SQLErrorHandler* handler) {
    base::AutoLock auto_lock(lock_);
    errors_.push_back(handler);
  }

  ErrorList errors_;
  base::Lock lock_;
};

// Process-wide default factory, constructed lazily on first use.
static base::LazyInstance<DefaultSQLErrorHandlerFactory>
    g_default_sql_error_handler_factory(base::LINKER_INITIALIZED);

// Returns the shared error handler factory (used, for example, by
// SQLStatement::prepare() and SQLStatement::step() in this file).
SQLErrorHandlerFactory* GetErrorHandlerFactory() {
  // TODO(cpu): Testing needs to override the error handler.
  // The LazyInstance is torn down by the at_exit manager, which destroys the
  // DefaultSQLErrorHandlerFactory and every handler it owns.
  DefaultSQLErrorHandlerFactory* const factory =
      g_default_sql_error_handler_factory.Pointer();
  return factory;
}

namespace sqlite_utils {

// Opens the SQLite database at |filepath| with sqlite3_open(), storing the
// connection in |*database|. Returns the sqlite3_open() result code.
int OpenSqliteDb(const FilePath& filepath, sqlite3** database) {
#if defined(OS_WIN)
  // The path is a wide string here. Convert it and use the 8-bit
  // sqlite3_open() so the database's default encoding is UTF-8.
  const std::string utf8_path = WideToUTF8(filepath.value());
  return sqlite3_open(utf8_path.c_str(), database);
#elif defined(OS_POSIX)
  // The native path is already a narrow string; pass it straight through.
  const FilePath::StringType& native_path = filepath.value();
  return sqlite3_open(native_path.c_str(), database);
#endif
}

// Returns true if sqlite_master lists a table named |table_name|. When
// |db_name| is non-empty it is used as the schema prefix of sqlite_master;
// it is spliced into the SQL unescaped, so it must be trusted. Returns false
// if the statement cannot be prepared or bound.
bool DoesSqliteTableExist(sqlite3* db,
                          const char* db_name,
                          const char* table_name) {
  // Table names cannot be bound as parameters, so the (optionally
  // schema-qualified) sqlite_master name is assembled by hand; only
  // |table_name| goes through a bound parameter.
  std::string query = "SELECT name FROM ";
  const bool has_db_name = (db_name != NULL) && (db_name[0] != '\0');
  if (has_db_name)
    query += std::string(db_name) + ".";
  query += "sqlite_master WHERE type='table' AND name=?";

  SQLStatement statement;
  if (statement.prepare(db, query.c_str()) != SQLITE_OK ||
      statement.bind_text(0, table_name) != SQLITE_OK) {
    return false;
  }

  // Only whether a row matched matters, not its contents.
  return sqlite3_step(statement.get()) == SQLITE_ROW;
}

// Returns true if |table_name| (optionally prefixed by |database_name|) has a
// column named |column_name|. When |column_type| is non-empty, the first
// column with that name must also have exactly that type. Both comparisons
// are case-sensitive. Returns false if the PRAGMA cannot be prepared. Names
// are spliced into the SQL unescaped, so they must be trusted.
bool DoesSqliteColumnExist(sqlite3* db,
                           const char* database_name,
                           const char* table_name,
                           const char* column_name,
                           const char* column_type) {
  std::string pragma = "PRAGMA ";
  if (database_name != NULL && database_name[0] != '\0')
    pragma += std::string(database_name) + ".";
  pragma += std::string("TABLE_INFO(") + table_name + ")";

  SQLStatement statement;
  if (statement.prepare(db, pragma.c_str()) != SQLITE_OK)
    return false;

  const bool check_type = (column_type != NULL) && (column_type[0] != '\0');
  // Each TABLE_INFO row carries the column name in field 1 and its type in
  // field 2. The first row whose name matches decides the result.
  while (statement.step() == SQLITE_ROW) {
    if (statement.column_string(1).compare(column_name) != 0)
      continue;
    if (!check_type)
      return true;
    return statement.column_string(2).compare(column_type) == 0;
  }
  return false;
}

// Returns true if |table_name| yields at least one row, and false if it is
// empty or the query cannot be prepared. |table_name| is spliced into the SQL
// unescaped, so it must be trusted.
bool DoesSqliteTableHaveRow(sqlite3* db, const char* table_name) {
  const std::string query = std::string("SELECT * FROM ") + table_name;

  SQLStatement statement;
  if (statement.prepare(db, query.c_str()) != SQLITE_OK)
    return false;

  // A single step is enough: SQLITE_ROW means a row exists.
  return statement.step() == SQLITE_ROW;
}

}  // namespace sqlite_utils

// Attaches the transaction to |db|. No SQL is run until BeginCommand().
SQLTransaction::SQLTransaction(sqlite3* db)
    : db_(db),
      began_(false) {
}

// A transaction that is still open at destruction time is rolled back.
SQLTransaction::~SQLTransaction() {
  if (!began_)
    return;
  Rollback();
}

// Runs |command| to open the transaction, but only when none is open yet and
// a database is attached. The transaction counts as open only if the command
// succeeds. Returns the sqlite3_exec() result, or SQLITE_ERROR if skipped.
int SQLTransaction::BeginCommand(const char* command) {
  if (began_ || !db_)
    return SQLITE_ERROR;
  const int result = sqlite3_exec(db_, command, NULL, NULL, NULL);
  began_ = (result == SQLITE_OK);
  return result;
}

// Runs |command| to close an open transaction. Returns the sqlite3_exec()
// result, or SQLITE_ERROR if nothing is open or no database is attached.
int SQLTransaction::EndCommand(const char* command) {
  if (!began_ || !db_)
    return SQLITE_ERROR;
  const int result = sqlite3_exec(db_, command, NULL, NULL, NULL);
  // A failed end command leaves the transaction marked open, so the
  // destructor will still roll it back.
  began_ = (result != SQLITE_OK);
  return result;
}

// The site may only be destroyed once no top-level nested transaction is
// registered with it (checked in debug builds).
SQLNestedTransactionSite::~SQLNestedTransactionSite() {
  DCHECK(!top_transaction_);
}

// Registers |top| as the site's top-level nested transaction, or clears the
// registration when |top| is NULL. Setting a non-NULL |top| while another
// transaction is already registered trips the DCHECK.
void SQLNestedTransactionSite::SetTopTransaction(SQLNestedTransaction* top) {
  DCHECK(!top || !top_transaction_);
  top_transaction_ = top;
}

// Uses |site|'s database. If the site has no top-level transaction yet, this
// one registers itself as the top level; otherwise it acts as an inner
// transaction.
// NOTE(review): |site| is dereferenced in the initializer list before the
// DCHECK below can catch a NULL argument.
SQLNestedTransaction::SQLNestedTransaction(SQLNestedTransactionSite* site)
  : SQLTransaction(site->GetSqlite3DB()),
    needs_rollback_(false),
    site_(site) {
  DCHECK(site);
  const bool site_is_idle = (site->GetTopTransaction() == NULL);
  if (site_is_idle)
    site->SetTopTransaction(this);
}

// Rolls back if still open, then unregisters from the site if this was the
// top-level transaction.
SQLNestedTransaction::~SQLNestedTransaction() {
  if (began_)
    Rollback();
  const bool is_top = (site_->GetTopTransaction() == this);
  if (is_top)
    site_->SetTopTransaction(NULL);
}

int SQLNestedTransaction::BeginCommand(const char* command) {
  DCHECK(db_);
  DCHECK(site_ && site_->GetTopTransaction());
  if (!db_ || began_) {
    return SQLITE_ERROR;
  }
  if (site_->GetTopTransaction() == this) {
    int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
    began_ = (rv == SQLITE_OK);
    if (began_) {
      site_->OnBegin();
    }
    return rv;
  } else {
    if (site_->GetTopTransaction()->needs_rollback_) {
      return SQLITE_ERROR;
    }
    began_ = true;
    return SQLITE_OK;
  }
}

// Ends this nested transaction by running |command|, which is expected to be
// "COMMIT" or "ROLLBACK".
//
// Only the top-level transaction issues SQL. An inner transaction just marks
// itself ended. If it is rolling back, it also flags the top-level
// transaction with |needs_rollback_|, which turns the top-level's own end
// command into a ROLLBACK that returns SQLITE_ERROR. The site is told about
// the outcome via OnCommit()/OnRollback() when the top-level transaction ends.
// Returns SQLITE_ERROR if there is no database or nothing has begun.
int SQLNestedTransaction::EndCommand(const char* command) {
  DCHECK(db_);
  DCHECK(site_ && site_->GetTopTransaction());
  if (!db_ || !began_) {
    return SQLITE_ERROR;
  }
  if (site_->GetTopTransaction() == this) {
    if (needs_rollback_) {
      // An inner transaction rolled back, so the whole unit is undone
      // regardless of |command|.
      sqlite3_exec(db_, "ROLLBACK", NULL, NULL, NULL);
      began_ = false;  // reset so we don't try to rollback or call
                       // OnRollback() again
      // The forced rollback has now happened; clear the flag so this object
      // can begin and commit a fresh transaction later. Leaving it set would
      // turn every later COMMIT on this object into a rollback and make every
      // inner BeginCommand() fail.
      needs_rollback_ = false;
      site_->OnRollback();
      return SQLITE_ERROR;
    } else {
      int rv = sqlite3_exec(db_, command, NULL, NULL, NULL);
      // A failed COMMIT leaves the transaction open so it is rolled back later.
      began_ = (rv != SQLITE_OK);
      if (strcmp(command, "ROLLBACK") == 0) {
        began_ = false;  // reset so we don't try to rollback or call
                         // OnRollback() again
        site_->OnRollback();
      } else {
        DCHECK(strcmp(command, "COMMIT") == 0);
        if (rv == SQLITE_OK) {
          site_->OnCommit();
        }
      }
      return rv;
    }
  } else {
    // Inner transaction: no SQL runs here. A rollback dooms the top level.
    if (strcmp(command, "ROLLBACK") == 0) {
      site_->GetTopTransaction()->needs_rollback_ = true;
    }
    began_ = false;
    return SQLITE_OK;
  }
}

// Compiles |sql| (|sql_len| bytes; negative means NUL-terminated) against
// |db| into this statement. On failure the code is passed to a handler from
// GetErrorHandlerFactory(), and the handler's return value is returned.
int SQLStatement::prepare(sqlite3* db, const char* sql, int sql_len) {
  DCHECK(!stmt_);
  const int result = sqlite3_prepare_v2(db, sql, sql_len, &stmt_, NULL);
  if (result == SQLITE_OK)
    return SQLITE_OK;
  return GetErrorHandlerFactory()->Make()->HandleError(result, db);
}

// Advances the statement by one row. SQLITE_ROW and SQLITE_DONE are returned
// as-is; any other code is routed through an error handler, whose return
// value is returned.
int SQLStatement::step() {
  DCHECK(stmt_);
  const int result = sqlite3_step(stmt_);
  switch (result) {
    case SQLITE_ROW:
    case SQLITE_DONE:
      return result;
    default:
      return GetErrorHandlerFactory()->Make()->HandleError(result,
                                                           db_handle());
  }
}

// Resets the prepared statement via sqlite3_reset() and returns its result.
int SQLStatement::reset() {
  DCHECK(stmt_);
  const int result = sqlite3_reset(stmt_);
  return result;
}

// Returns sqlite3_last_insert_rowid() for the connection that owns this
// statement (a per-connection value, not specific to this statement).
sqlite_int64 SQLStatement::last_insert_rowid() {
  DCHECK(stmt_);
  sqlite3* const connection = db_handle();
  return sqlite3_last_insert_rowid(connection);
}

int SQLStatement::changes() {
  DCHECK(stmt_);
  return sqlite3_changes(db_handle());
}

// Returns the database connection this prepared statement belongs to.
sqlite3* SQLStatement::db_handle() {
  DCHECK(stmt_);
  sqlite3* const owner = sqlite3_db_handle(stmt_);
  return owner;
}

// Returns the number of bindable parameters in the prepared statement.
int SQLStatement::bind_parameter_count() {
  DCHECK(stmt_);
  const int count = sqlite3_bind_parameter_count(stmt_);
  return count;
}

// Binds the bytes of |*blob| to the zero-based parameter |index| (copied by
// SQLite via the SQLITE_TRANSIENT overload), or SQL NULL when |blob| itself
// is NULL. An empty vector is bound as a zero-length blob, not as NULL.
int SQLStatement::bind_blob(int index, std::vector<unsigned char>* blob) {
  if (!blob)
    return bind_null(index);

  // sqlite3_bind_blob() treats a NULL data pointer as SQL NULL, so an empty
  // vector must still supply a valid non-NULL pointer (with length 0) to be
  // stored as an empty blob.
  static const unsigned char kEmptyBlob = 0;
  const void* data = blob->empty() ? &kEmptyBlob : &(*blob)[0];
  const int len = static_cast<int>(blob->size());
  return bind_blob(index, data, len);
}

int SQLStatement::bind_blob(int index, const void* value, int value_len) {
   return bind_blob(index, value, value_len, SQLITE_TRANSIENT);
}

// Binds |value_len| bytes at |value| to the zero-based parameter |index|.
// |dtor| is forwarded to SQLite to control how the buffer is managed.
int SQLStatement::bind_blob(int index, const void* value, int value_len,
                            Function dtor) {
  DCHECK(stmt_);
  const int sqlite_index = index + 1;  // SQLite parameters are 1-based.
  return sqlite3_bind_blob(stmt_, sqlite_index, value, value_len, dtor);
}

// Binds |value| as a REAL to the zero-based parameter |index|.
int SQLStatement::bind_double(int index, double value) {
  DCHECK(stmt_);
  const int sqlite_index = index + 1;  // SQLite parameters are 1-based.
  return sqlite3_bind_double(stmt_, sqlite_index, value);
}

int SQLStatement::bind_bool(int index, bool value) {
  DCHECK(stmt_);
  return sqlite3_bind_int(stmt_, index + 1, value);
}

int SQLStatement::bind_int(int index, int value) {
  DCHECK(stmt_);
  return sqlite3_bind_int(stmt_, index + 1, value);
}

// Binds a 64-bit integer to the zero-based parameter |index|.
int SQLStatement::bind_int64(int index, sqlite_int64 value) {
  DCHECK(stmt_);
  const int sqlite_index = index + 1;  // SQLite parameters are 1-based.
  return sqlite3_bind_int64(stmt_, sqlite_index, value);
}

// Binds SQL NULL to the zero-based parameter |index|.
int SQLStatement::bind_null(int index) {
  DCHECK(stmt_);
  const int sqlite_index = index + 1;  // SQLite parameters are 1-based.
  return sqlite3_bind_null(stmt_, sqlite_index);
}

int SQLStatement::bind_text(int index, const char* value, int value_len,
              Function dtor) {
  DCHECK(stmt_);
  return sqlite3_bind_text(stmt_, index + 1, value, value_len, dtor);
}

int SQLStatement::bind_text16(int index, const char16* value, int value_len,
                Function dtor) {
  DCHECK(stmt_);
  value_len *= sizeof(char16);
  return sqlite3_bind_text16(stmt_, index + 1, value, value_len, dtor);
}

// Binds a copy of the SQLite value |value| to the zero-based parameter
// |index|.
int SQLStatement::bind_value(int index, const sqlite3_value* value) {
  DCHECK(stmt_);
  const int sqlite_index = index + 1;  // SQLite parameters are 1-based.
  return sqlite3_bind_value(stmt_, sqlite_index, value);
}

// Returns the number of columns in the statement's result set.
int SQLStatement::column_count() {
  DCHECK(stmt_);
  const int count = sqlite3_column_count(stmt_);
  return count;
}

// Returns the SQLite datatype code of column |index| (zero-based) in the
// current row.
int SQLStatement::column_type(int index) {
  DCHECK(stmt_);
  const int type = sqlite3_column_type(stmt_, index);
  return type;
}

// Returns a pointer to the blob data of column |index|. The pointer is owned
// by SQLite.
const void* SQLStatement::column_blob(int index) {
  DCHECK(stmt_);
  const void* data = sqlite3_column_blob(stmt_, index);
  return data;
}

bool SQLStatement::column_blob_as_vector(int index,
                                         std::vector<unsigned char>* blob) {
  DCHECK(stmt_);
  const void* p = column_blob(index);
  size_t len = column_bytes(index);
  blob->resize(len);
  if (blob->size() != len) {
    return false;
  }
  if (len > 0)
    memcpy(&(blob->front()), p, len);
  return true;
}

// Copies the blob in column |index| of the current row into |*blob| as raw
// bytes. A NULL or zero-length column yields an empty string and returns
// true. Returns false (with |*blob| cleared) when SQLite reports a non-zero
// size but no data pointer (presumably an allocation failure inside SQLite).
bool SQLStatement::column_blob_as_string(int index, std::string* blob) {
  DCHECK(stmt_);
  DCHECK(blob);
  // Fetch the data before the size, as SQLite recommends.
  const char* data = static_cast<const char*>(column_blob(index));
  const int len = column_bytes(index);
  if (len <= 0) {
    // Avoids assign(NULL, 0), which is not a valid range per the standard.
    blob->clear();
    return true;
  }
  if (!data) {
    blob->clear();
    return false;
  }
  // assign() replaces the contents directly; no zero-filling resize needed.
  blob->assign(data, static_cast<size_t>(len));
  return true;
}

// Returns the size in bytes of the blob or UTF-8 text in column |index|.
int SQLStatement::column_bytes(int index) {
  DCHECK(stmt_);
  const int size = sqlite3_column_bytes(stmt_, index);
  return size;
}

// Returns the size in bytes of the UTF-16 text in column |index|.
int SQLStatement::column_bytes16(int index) {
  DCHECK(stmt_);
  const int size = sqlite3_column_bytes16(stmt_, index);
  return size;
}

// Returns column |index| of the current row as a double.
double SQLStatement::column_double(int index) {
  DCHECK(stmt_);
  const double number = sqlite3_column_double(stmt_, index);
  return number;
}

// Returns true if column |index|, read as an int, is non-zero.
bool SQLStatement::column_bool(int index) {
  DCHECK(stmt_);
  const int raw = sqlite3_column_int(stmt_, index);
  return raw != 0;
}

// Returns column |index| of the current row as a 32-bit integer.
int SQLStatement::column_int(int index) {
  DCHECK(stmt_);
  const int number = sqlite3_column_int(stmt_, index);
  return number;
}

// Returns column |index| of the current row as a 64-bit integer.
sqlite_int64 SQLStatement::column_int64(int index) {
  DCHECK(stmt_);
  const sqlite_int64 number = sqlite3_column_int64(stmt_, index);
  return number;
}

// Returns the UTF-8 text of column |index| as a C string owned by SQLite
// (NULL when SQLite returns no text).
const char* SQLStatement::column_text(int index) {
  DCHECK(stmt_);
  const unsigned char* text = sqlite3_column_text(stmt_, index);
  return reinterpret_cast<const char*>(text);
}

// Copies the UTF-8 text of column |index| into |*str|. When SQLite returns
// no text (e.g. a NULL column), |*str| is emptied and false is returned.
bool SQLStatement::column_string(int index, std::string* str) {
  DCHECK(stmt_);
  DCHECK(str);
  const char* text = column_text(index);
  if (text == NULL) {
    str->clear();
    return false;
  }
  str->assign(text);
  return true;
}

// Convenience overload: returns column |index| as a std::string, empty when
// SQLite returns no text.
std::string SQLStatement::column_string(int index) {
  std::string result;
  column_string(index, &result);
  return result;
}

// Returns the UTF-16 text of column |index| as a buffer owned by SQLite.
const char16* SQLStatement::column_text16(int index) {
  DCHECK(stmt_);
  const void* text = sqlite3_column_text16(stmt_, index);
  return static_cast<const char16*>(text);
}

bool SQLStatement::column_string16(int index, string16* str) {
  DCHECK(stmt_);
  DCHECK(str);
  const char* s = column_text(index);
  str->assign(s ? UTF8ToUTF16(s) : string16());
  return (s != NULL);
}

// Convenience overload: returns column |index| converted to string16, empty
// when SQLite returns no text.
string16 SQLStatement::column_string16(int index) {
  string16 result;
  column_string16(index, &result);
  return result;
}

// Reads column |index| as UTF-8 text and stores its wide-string conversion in
// |*str|. When SQLite returns no text, |*str| is emptied and false returned.
bool SQLStatement::column_wstring(int index, std::wstring* str) {
  DCHECK(stmt_);
  DCHECK(str);
  const char* utf8 = column_text(index);
  if (!utf8) {
    str->clear();
    return false;
  }
  str->assign(UTF8ToWide(utf8));
  return true;
}

// Convenience overload: returns column |index| as a std::wstring, empty when
// SQLite returns no text.
std::wstring SQLStatement::column_wstring(int index) {
  std::wstring result;
  column_wstring(index, &result);
  return result;
}