// Plain text  |  479 lines  |  15.86 KB

// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/history/visitsegment_database.h"

#include <math.h>

#include <algorithm>
#include <string>
#include <vector>

#include "base/command_line.h"
#include "base/logging.h"
#include "base/stl_util.h"
#include "base/strings/string_util.h"
#include "base/strings/utf_string_conversions.h"
#include "chrome/browser/history/page_usage_data.h"
#include "chrome/common/chrome_switches.h"
#include "sql/statement.h"
#include "sql/transaction.h"

// The following tables are used to store url segment information.
//
// segments
//   id                 Primary key
//   name               A unique string to represent that segment. (URL derived)
//   url_id             ID of the url currently used to represent this segment.
//
// segment_usage
//   id                 Primary key
//   segment_id         Corresponding segment id
//   time_slot          Time stamp identifying which day this entry is for
//   visit_count        Number of visits to the segment during that time slot
//
// segment_duration
//   id                 Primary key
//   segment_id         Corresponding segment id
//   time_slot          time stamp identifying what day this entry is for
//   duration           Total time during the time_slot the user has been on
//                      the page. This is a serialized TimeDelta value.
// segment_duration is only created if chrome::kTrackActiveVisitTime is set.

namespace history {

// Samples the kTrackActiveVisitTime switch on the current process's command
// line once, at construction; the result decides whether the segment_duration
// table is used by the rest of this class.
VisitSegmentDatabase::VisitSegmentDatabase()
    : has_duration_table_(
          CommandLine::ForCurrentProcess()->HasSwitch(
              switches::kTrackActiveVisitTime)) {}

// Intentionally empty: no explicit cleanup is performed here.
VisitSegmentDatabase::~VisitSegmentDatabase() {}

// Makes sure the segment tables and their indices exist, creating whatever is
// missing. Returns false as soon as any SQL statement fails and true once
// every step has succeeded.
bool VisitSegmentDatabase::InitSegmentTables() {
  // segments: the table and its name index are created together the first
  // time through.
  // NOTE(review): "NON NULL" is reproduced verbatim from the existing schema.
  // It looks like a typo for "NOT NULL" and presumably does not enforce a
  // constraint -- confirm before changing, since the same text appears in
  // MigratePresentationIndex() and altering it changes the schema.
  if (!GetDB().DoesTableExist("segments")) {
    const bool segments_created =
        GetDB().Execute("CREATE TABLE segments ("
                        "id INTEGER PRIMARY KEY,"
                        "name VARCHAR,"
                        "url_id INTEGER NON NULL)") &&
        GetDB().Execute("CREATE INDEX segments_name ON segments(name)");
    if (!segments_created)
      return false;
  }

  // The url_id index was introduced after the table itself, so it has to be
  // attempted on every run, whether or not the table was just created.
  if (!GetDB().Execute("CREATE INDEX IF NOT EXISTS segments_url_id ON "
                       "segments(url_id)")) {
    return false;
  }

  // segment_usage: table plus its (time_slot, segment_id) index.
  if (!GetDB().DoesTableExist("segment_usage")) {
    const bool usage_created =
        GetDB().Execute("CREATE TABLE segment_usage ("
                        "id INTEGER PRIMARY KEY,"
                        "segment_id INTEGER NOT NULL,"
                        "time_slot INTEGER NOT NULL,"
                        "visit_count INTEGER DEFAULT 0 NOT NULL)") &&
        GetDB().Execute("CREATE INDEX segment_usage_time_slot_segment_id ON "
                        "segment_usage(time_slot, segment_id)");
    if (!usage_created)
      return false;
  }

  // The segment_id index was also added in a later version, so always try to
  // create it.
  if (!GetDB().Execute("CREATE INDEX IF NOT EXISTS segments_usage_seg_id "
                       "ON segment_usage(segment_id)")) {
    return false;
  }

  // TODO(sky): if we decide to keep this feature, duration should be added to
  // segment_usage.
  if (!has_duration_table_) {
    // Duration tracking is off: remove any table left behind by an earlier
    // run that had it on.
    return GetDB().Execute("DROP TABLE IF EXISTS segment_duration");
  }

  if (GetDB().DoesTableExist("segment_duration"))
    return true;

  return GetDB().Execute("CREATE TABLE segment_duration ("
                         "id INTEGER PRIMARY KEY,"
                         "segment_id INTEGER NOT NULL,"
                         "time_slot INTEGER NOT NULL,"
                         "duration INTEGER DEFAULT 0 NOT NULL)") &&
         GetDB().Execute(
             "CREATE INDEX segment_duration_time_slot_segment_id ON "
             "segment_duration(time_slot, segment_id)");
}

// Drops the segment tables in order, stopping at the first statement that
// fails. Returns true only if all three statements succeed.
bool VisitSegmentDatabase::DropSegmentTables() {
  // Indices are removed implicitly along with the table they belong to, so
  // only the tables themselves are dropped.
  if (!GetDB().Execute("DROP TABLE segments"))
    return false;
  if (!GetDB().Execute("DROP TABLE segment_usage"))
    return false;
  // segment_duration only exists when duration tracking was enabled, hence
  // the IF EXISTS.
  return GetDB().Execute("DROP TABLE IF EXISTS segment_duration");
}

// Note: the segment name is derived from the URL but is not a URL. It is
// a string that can be easily recreated from various URLS. Maybe this should
// be an MD5 to limit the length.
//
// static
std::string VisitSegmentDatabase::ComputeSegmentName(const GURL& url) {
  // TODO(brettw) this should probably use the registry controlled
  // domains service.
  GURL::Replacements r;
  const char kWWWDot[] = "www.";
  const int kWWWDotLen = arraysize(kWWWDot) - 1;

  std::string host = url.host();
  const char* host_c = host.c_str();
  // Remove www. to avoid some dups.
  if (static_cast<int>(host.size()) > kWWWDotLen &&
      LowerCaseEqualsASCII(host_c, host_c + kWWWDotLen, kWWWDot)) {
    r.SetHost(host.c_str(),
              url_parse::Component(kWWWDotLen,
                  static_cast<int>(host.size()) - kWWWDotLen));
  }
  // Remove other stuff we don't want.
  r.ClearUsername();
  r.ClearPassword();
  r.ClearQuery();
  r.ClearRef();
  r.ClearPort();

  return url.ReplaceComponents(r).spec();
}

// Maps |time| to the time slot it is stored under, which is whatever
// base::Time::LocalMidnight() returns for it (presumably the start of that
// day in local time).
//
// static
base::Time VisitSegmentDatabase::SegmentTime(base::Time time) {
  const base::Time slot_start = time.LocalMidnight();
  return slot_start;
}

// Returns the id of the segment whose name equals |segment_name|, or 0 when
// stepping the query yields no row.
SegmentID VisitSegmentDatabase::GetSegmentNamed(
    const std::string& segment_name) {
  sql::Statement lookup(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "SELECT id FROM segments WHERE name = ?"));
  lookup.BindString(0, segment_name);

  if (!lookup.Step())
    return 0;
  return lookup.ColumnInt64(0);
}

// Points the segment identified by |segment_id| at |url_id| as its
// representative URL. Returns the result of running the UPDATE statement.
bool VisitSegmentDatabase::UpdateSegmentRepresentationURL(SegmentID segment_id,
                                                          URLID url_id) {
  sql::Statement update(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "UPDATE segments SET url_id = ? WHERE id = ?"));
  update.BindInt64(0, url_id);
  update.BindInt64(1, segment_id);

  const bool succeeded = update.Run();
  return succeeded;
}

// Returns the url_id stored for the segment |segment_id|, or 0 when stepping
// the query yields no row.
URLID VisitSegmentDatabase::GetSegmentRepresentationURL(SegmentID segment_id) {
  sql::Statement lookup(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "SELECT url_id FROM segments WHERE id = ?"));
  lookup.BindInt64(0, segment_id);

  if (!lookup.Step())
    return 0;
  return lookup.ColumnInt64(0);
}

// Inserts a new segment row named |segment_name| that is represented by
// |url_id|. Returns the row id of the inserted row, or 0 if the INSERT did
// not run successfully.
SegmentID VisitSegmentDatabase::CreateSegment(URLID url_id,
                                              const std::string& segment_name) {
  sql::Statement insert(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "INSERT INTO segments (name, url_id) VALUES (?,?)"));
  insert.BindString(0, segment_name);
  insert.BindInt64(1, url_id);

  if (!insert.Run())
    return 0;
  return GetDB().GetLastInsertRowId();
}

// Adds |amount| to the visit count recorded for |segment_id| in the time slot
// that |ts| falls into. If no row exists yet for that (time slot, segment)
// pair, one is inserted with |amount| as its initial count. Returns false if
// the lookup statement is invalid or the UPDATE/INSERT fails.
bool VisitSegmentDatabase::IncreaseSegmentVisitCount(SegmentID segment_id,
                                                     base::Time ts,
                                                     int amount) {
  const int64 time_slot = SegmentTime(ts).ToInternalValue();
  const int64 increment = static_cast<int64>(amount);

  sql::Statement existing(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "SELECT id, visit_count FROM segment_usage "
      "WHERE time_slot = ? AND segment_id = ?"));
  existing.BindInt64(0, time_slot);
  existing.BindInt64(1, segment_id);

  if (!existing.is_valid())
    return false;

  if (!existing.Step()) {
    // No usage row for this segment in this time slot yet: start one.
    sql::Statement insert(GetDB().GetCachedStatement(SQL_FROM_HERE,
        "INSERT INTO segment_usage "
        "(segment_id, time_slot, visit_count) VALUES (?, ?, ?)"));
    insert.BindInt64(0, segment_id);
    insert.BindInt64(1, time_slot);
    insert.BindInt64(2, increment);

    return insert.Run();
  }

  // A row already exists: write back its count plus the increment.
  sql::Statement update(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "UPDATE segment_usage SET visit_count = ? WHERE id = ?"));
  const int64 row_id = existing.ColumnInt64(0);
  const int64 new_count = existing.ColumnInt64(1) + increment;
  update.BindInt64(0, new_count);
  update.BindInt64(1, row_id);

  return update.Run();
}

// Appends to |result| the segments ranked by visit count, considering usage
// rows from the time slot of |from_time| onwards and keeping at most
// |max_result_count| entries. Does nothing if the query cannot be prepared.
void VisitSegmentDatabase::QuerySegmentUsage(
    base::Time from_time,
    int max_result_count,
    std::vector<PageUsageData*>* result) {
  // Collects the per-day visit counts of every segment; the rows must arrive
  // grouped by segment for QuerySegmentsCommon() to total them.
  sql::Statement usage_rows(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "SELECT segment_id, time_slot, visit_count "
      "FROM segment_usage WHERE time_slot >= ? "
      "ORDER BY segment_id"));
  if (usage_rows.is_valid()) {
    QuerySegmentsCommon(&usage_rows, from_time, max_result_count,
                        QUERY_VISIT_COUNT, result);
  }
}

// Deletes usage rows (and, when the duration table is in use, duration rows)
// whose time slot is earlier than the time slot of |older_than|. Returns
// false as soon as either DELETE fails.
bool VisitSegmentDatabase::DeleteSegmentData(base::Time older_than) {
  const int64 cutoff_slot = SegmentTime(older_than).ToInternalValue();

  sql::Statement delete_usage(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "DELETE FROM segment_usage WHERE time_slot < ?"));
  delete_usage.BindInt64(0, cutoff_slot);
  if (!delete_usage.Run())
    return false;

  if (has_duration_table_) {
    sql::Statement delete_duration(GetDB().GetCachedStatement(SQL_FROM_HERE,
        "DELETE FROM segment_duration WHERE time_slot < ?"));
    delete_duration.BindInt64(0, cutoff_slot);
    return delete_duration.Run();
  }

  return true;
}

// Removes every segment represented by |url_id| together with its usage rows
// and, when the duration table is in use, its duration rows. Returns false as
// soon as one of the DELETE statements fails.
bool VisitSegmentDatabase::DeleteSegmentForURL(URLID url_id) {
  // The dependent rows go first: their DELETEs locate the segments through a
  // sub-select on the segments table, so those rows must still be there.
  sql::Statement usage_delete(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "DELETE FROM segment_usage WHERE segment_id IN "
      "(SELECT id FROM segments WHERE url_id = ?)"));
  usage_delete.BindInt64(0, url_id);
  const bool usage_deleted = usage_delete.Run();
  if (!usage_deleted)
    return false;

  if (has_duration_table_) {
    sql::Statement duration_delete(GetDB().GetCachedStatement(SQL_FROM_HERE,
        "DELETE FROM segment_duration WHERE segment_id IN "
        "(SELECT id FROM segments WHERE url_id = ?)"));
    duration_delete.BindInt64(0, url_id);
    const bool duration_deleted = duration_delete.Run();
    if (!duration_deleted)
      return false;
  }

  // Finally remove the segment rows themselves.
  sql::Statement segment_delete(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "DELETE FROM segments WHERE url_id = ?"));
  segment_delete.BindInt64(0, url_id);
  return segment_delete.Run();
}

// Inserts a duration row for |segment_id| in the time slot that |time| falls
// into, storing |delta| in its serialized internal form. Returns the new
// row's id, or 0 when the duration table is not in use or the INSERT fails.
SegmentDurationID VisitSegmentDatabase::CreateSegmentDuration(
    SegmentID segment_id,
    base::Time time,
    base::TimeDelta delta) {
  if (!has_duration_table_)
    return 0;

  sql::Statement insert(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "INSERT INTO segment_duration (segment_id, time_slot, duration) "
      "VALUES (?,?,?)"));
  insert.BindInt64(0, segment_id);
  insert.BindInt64(1, SegmentTime(time).ToInternalValue());
  insert.BindInt64(2, delta.ToInternalValue());

  if (!insert.Run())
    return 0;
  return GetDB().GetLastInsertRowId();
}

// Overwrites the duration stored in the row |duration_id| with |time_delta|.
// Returns false when the duration table is not in use or the UPDATE fails.
bool VisitSegmentDatabase::SetSegmentDuration(SegmentDurationID duration_id,
                                              base::TimeDelta time_delta) {
  if (has_duration_table_) {
    sql::Statement update(GetDB().GetCachedStatement(SQL_FROM_HERE,
        "UPDATE segment_duration SET duration = ? WHERE id = ?"));
    update.BindInt64(0, time_delta.ToInternalValue());
    update.BindInt64(1, duration_id);
    return update.Run();
  }
  return false;
}

// Looks up the duration row for |segment_id| in the time slot that |time|
// falls into. On success writes the row id to |duration_id| and the stored
// duration to |time_delta| and returns true. Returns false, leaving both
// out-parameters untouched, when the duration table is not in use, the
// statement is invalid, or no row is found.
bool VisitSegmentDatabase::GetSegmentDuration(SegmentID segment_id,
                                              base::Time time,
                                              SegmentDurationID* duration_id,
                                              base::TimeDelta* time_delta) {
  if (!has_duration_table_)
    return false;

  sql::Statement lookup(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "SELECT id, duration FROM segment_duration "
      "WHERE segment_id = ? AND time_slot = ? "));
  if (!lookup.is_valid())
    return false;

  lookup.BindInt64(0, segment_id);
  lookup.BindInt64(1, SegmentTime(time).ToInternalValue());

  const bool row_found = lookup.Step();
  if (row_found) {
    *duration_id = lookup.ColumnInt64(0);
    *time_delta = base::TimeDelta::FromInternalValue(lookup.ColumnInt64(1));
  }
  return row_found;
}

// Appends to |result| the segments ranked by time spent, considering duration
// rows from the time slot of |from_time| onwards and keeping at most
// |max_result_count| entries. Does nothing when the duration table is not in
// use or the query cannot be prepared.
void VisitSegmentDatabase::QuerySegmentDuration(
    base::Time from_time,
    int max_result_count,
    std::vector<PageUsageData*>* result) {
  if (has_duration_table_) {
    // Collects the per-day durations of every segment; the rows must arrive
    // grouped by segment for QuerySegmentsCommon() to total them.
    sql::Statement duration_rows(GetDB().GetCachedStatement(SQL_FROM_HERE,
        "SELECT segment_id, time_slot, duration "
        "FROM segment_duration WHERE time_slot >= ? "
        "ORDER BY segment_id"));
    if (duration_rows.is_valid()) {
      QuerySegmentsCommon(&duration_rows, from_time, max_result_count,
                          QUERY_DURATION, result);
    }
  }
}

// Drops the presentation table and rebuilds the segments table with only the
// id, name and url_id columns, all inside one transaction. Returns false at
// the first step that fails; in that case Commit() is never reached.
bool VisitSegmentDatabase::MigratePresentationIndex() {
  sql::Transaction transaction(&GetDB());
  if (!transaction.Begin())
    return false;

  if (!GetDB().Execute("DROP TABLE presentation"))
    return false;

  // Build the replacement table under a temporary name.
  // NOTE(review): "NON NULL" is reproduced verbatim; it looks like a typo for
  // "NOT NULL" -- confirm before changing, since it alters the schema.
  if (!GetDB().Execute("CREATE TABLE segments_tmp ("
                       "id INTEGER PRIMARY KEY,"
                       "name VARCHAR,"
                       "url_id INTEGER NON NULL)")) {
    return false;
  }

  // Copy over the columns that are kept.
  if (!GetDB().Execute("INSERT INTO segments_tmp SELECT "
                       "id, name, url_id FROM segments")) {
    return false;
  }

  // Swap the new table in for the old one.
  if (!GetDB().Execute("DROP TABLE segments"))
    return false;
  if (!GetDB().Execute("ALTER TABLE segments_tmp RENAME TO segments"))
    return false;

  return transaction.Commit();
}


// Shared implementation behind QuerySegmentUsage() and QuerySegmentDuration().
//
// |statement| must yield (segment_id, time_slot, value) rows grouped by
// segment_id and take the earliest time slot as its only bind parameter;
// |value| is read as a visit count when |query_type| is QUERY_VISIT_COUNT and
// as a serialized TimeDelta otherwise. One PageUsageData is allocated per
// segment, scored, and appended to |result|; |result| is then sorted with
// PageUsageData::Predicate and trimmed to at most |max_result_count| entries
// (the trimmed objects are deleted here). A non-positive |max_result_count|
// leaves |result| empty. Finally the URL and title of each surviving entry
// are filled in. The objects left in |result| are presumably owned by the
// caller -- confirm against callers.
void VisitSegmentDatabase::QuerySegmentsCommon(
    sql::Statement* statement,
    base::Time from_time,
    int max_result_count,
    QueryType query_type,
    std::vector<PageUsageData*>* result) {
  // This function gathers the highest-ranked segments in two queries.
  // The first gathers scores for all segments.
  // The second gathers segment data (url, title, etc.) for the highest-ranked
  // segments.

  base::Time ts = SegmentTime(from_time);
  statement->BindInt64(0, ts.ToInternalValue());

  base::Time now = base::Time::Now();
  SegmentID last_segment_id = 0;
  PageUsageData* pud = NULL;
  float score = 0;
  base::TimeDelta duration;
  while (statement->Step()) {
    SegmentID segment_id = statement->ColumnInt64(0);
    if (segment_id != last_segment_id) {
      // Rows are grouped by segment, so a new id means the previous segment
      // is complete and can be emitted.
      if (pud) {
        pud->SetScore(score);
        pud->SetDuration(duration);
        result->push_back(pud);
      }

      pud = new PageUsageData(segment_id);
      score = 0;
      last_segment_id = segment_id;
      duration = base::TimeDelta();
    }

    base::Time timeslot =
        base::Time::FromInternalValue(statement->ColumnInt64(1));
    int count;
    if (query_type == QUERY_VISIT_COUNT) {
      count = statement->ColumnInt(2);
    } else {
      base::TimeDelta current_duration(
          base::TimeDelta::FromInternalValue(statement->ColumnInt64(2)));
      duration += current_duration;
      // Shouldn't overflow since we group by day.
      count = static_cast<int>(current_duration.InSeconds());
    }

    // log() is only meaningful for positive input: log(0) is -infinity and
    // the log of a negative number is NaN. Either would poison the segment's
    // score and make the ordering used by the sort below unreliable. A row
    // with no measurable usage (e.g. a duration under one second, or a
    // visit_count of 0) therefore contributes nothing to the score.
    float day_score = 0.0f;
    if (count > 0)
      day_score = 1.0f + log(static_cast<float>(count));

    // Recent visits count more than historical ones, so we multiply in a boost
    // related to how long ago this day was.
    // This boost is a curve that smoothly goes through these values:
    // Today gets 3x, a week ago 2x, three weeks ago 1.5x, falling off to 1x
    // at the limit of how far we reach into the past.
    // A time slot in the future (e.g. after the clock was moved back) is
    // treated as today; otherwise the denominator could reach zero or go
    // negative and yield an infinite or negative boost.
    int days_ago = std::max(0, (now - timeslot).InDays());
    float recency_boost = 1.0f + (2.0f * (1.0f / (1.0f + days_ago/7.0f)));
    score += recency_boost * day_score;
  }

  // Emit the last segment, which has no following row to trigger it.
  if (pud) {
    pud->SetScore(score);
    pud->SetDuration(duration);
    result->push_back(pud);
  }

  // Limit to the top |max_result_count| results. A negative limit is treated
  // as zero so that the iterator arithmetic and resize() below stay in range.
  std::sort(result->begin(), result->end(), PageUsageData::Predicate);
  const size_t max_results =
      max_result_count > 0 ? static_cast<size_t>(max_result_count) : 0;
  if (result->size() > max_results) {
    STLDeleteContainerPointers(result->begin() + max_results, result->end());
    result->resize(max_results);
  }

  // Now fetch the details about the entries we care about.
  sql::Statement statement2(GetDB().GetCachedStatement(SQL_FROM_HERE,
      "SELECT urls.url, urls.title FROM urls "
      "JOIN segments ON segments.url_id = urls.id "
      "WHERE segments.id = ?"));

  if (!statement2.is_valid())
    return;

  for (size_t i = 0; i < result->size(); ++i) {
    PageUsageData* entry = (*result)[i];
    statement2.BindInt64(0, entry->GetID());
    if (statement2.Step()) {
      entry->SetURL(GURL(statement2.ColumnString(0)));
      entry->SetTitle(statement2.ColumnString16(1));
    }
    // Reset so the same statement can be bound again for the next entry.
    statement2.Reset(true);
  }
}

}  // namespace history