// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/download/download_query.h"
#include <algorithm>
#include <string>
#include <vector>
#include "base/bind.h"
#include "base/callback.h"
#include "base/files/file_path.h"
#include "base/i18n/case_conversion.h"
#include "base/i18n/string_search.h"
#include "base/logging.h"
#include "base/memory/scoped_ptr.h"
#include "base/prefs/pref_service.h"
#include "base/stl_util.h"
#include "base/strings/string16.h"
#include "base/strings/string_split.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/time/time.h"
#include "base/values.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/common/pref_names.h"
#include "content/public/browser/content_browser_client.h"
#include "content/public/browser/download_item.h"
#include "net/base/net_util.h"
#include "third_party/re2/re2/re2.h"
#include "url/gurl.h"
using content::DownloadDangerType;
using content::DownloadItem;
namespace {
// GetAs<T>() gives the typed base::Value::GetAs*() getters one templated
// spelling, so generic code can extract a value by its C++ type. Every
// specialization forwards to the matching getter and returns its result.
template <typename T> bool GetAs(const base::Value& in, T* out);

// bool: forwards to base::Value::GetAsBoolean().
template<> bool GetAs(const base::Value& in, bool* out) {
  const bool converted = in.GetAsBoolean(out);
  return converted;
}

// int: forwards to base::Value::GetAsInteger().
template<> bool GetAs(const base::Value& in, int* out) {
  const bool converted = in.GetAsInteger(out);
  return converted;
}

// UTF-8 string: forwards to base::Value::GetAsString().
template<> bool GetAs(const base::Value& in, std::string* out) {
  const bool converted = in.GetAsString(out);
  return converted;
}

// string16: forwards to base::Value::GetAsString().
template<> bool GetAs(const base::Value& in, base::string16* out) {
  const bool converted = in.GetAsString(out);
  return converted;
}
// List of string16: replaces the contents of |*out| with the elements of the
// list held by |in|. Returns false, leaving |*out| empty, when |in| is not a
// list or when any element cannot be read as a string.
template<> bool GetAs(const base::Value& in, std::vector<base::string16>* out) {
  std::vector<base::string16> parsed;
  const base::ListValue* list_value = NULL;
  bool all_strings = in.GetAsList(&list_value);
  for (size_t index = 0; all_strings && index < list_value->GetSize();
       ++index) {
    base::string16 text;
    all_strings = list_value->GetString(index, &text);
    if (all_strings)
      parsed.push_back(text);
  }
  // A partial result is never exposed: on failure the caller gets an empty
  // vector.
  if (!all_strings)
    parsed.clear();
  out->swap(parsed);
  return all_strings;
}
// The next several functions are helpers for making Callbacks that access
// DownloadItem fields.
// Returns true when every term in |query_terms| is found in at least one of:
//   - the spec of the item's original URL,
//   - that URL as formatted by net::FormatUrl() with the profile's
//     accept-languages pref (only when the item has a browser context;
//     otherwise the raw spec is used here too), or
//   - the lossy display name of the item's target file path.
// |query_terms| must not be empty.
static bool MatchesQuery(
    const std::vector<base::string16>& query_terms,
    const DownloadItem& item) {
  DCHECK(!query_terms.empty());

  const base::string16 raw_url(UTF8ToUTF16(item.GetOriginalUrl().spec()));
  base::string16 display_url = raw_url;
  if (item.GetBrowserContext()) {
    Profile* profile = Profile::FromBrowserContext(item.GetBrowserContext());
    display_url = net::FormatUrl(
        item.GetOriginalUrl(),
        profile->GetPrefs()->GetString(prefs::kAcceptLanguages));
  }
  const base::string16 display_path(
      item.GetTargetFilePath().LossyDisplayName());

  typedef std::vector<base::string16>::const_iterator TermIterator;
  for (TermIterator term_it = query_terms.begin();
       term_it != query_terms.end(); ++term_it) {
    const base::string16 needle = base::i18n::ToLower(*term_it);
    // The candidates are tried in a fixed order and the search stops at the
    // first hit.
    const bool found =
        base::i18n::StringSearchIgnoringCaseAndAccents(
            needle, raw_url, NULL, NULL) ||
        base::i18n::StringSearchIgnoringCaseAndAccents(
            needle, display_url, NULL, NULL) ||
        base::i18n::StringSearchIgnoringCaseAndAccents(
            needle, display_path, NULL, NULL);
    if (!found)
      return false;
  }
  return true;
}
// Returns the item's start time as milliseconds elapsed since the Unix epoch.
static int64 GetStartTimeMsEpoch(const DownloadItem& item) {
  const base::Time epoch = base::Time::UnixEpoch();
  return (item.GetStartTime() - epoch).InMilliseconds();
}
// Returns the item's end time as milliseconds elapsed since the Unix epoch.
static int64 GetEndTimeMsEpoch(const DownloadItem& item) {
  const base::Time epoch = base::Time::UnixEpoch();
  return (item.GetEndTime() - epoch).InMilliseconds();
}
// Formats |t|, broken down in UTC, as "YYYY-MM-DDTHH:MM:SS.mmmZ". Every field
// is zero-padded to a fixed width.
std::string TimeToISO8601(const base::Time& t) {
  base::Time::Exploded utc;
  t.UTCExplode(&utc);
  return base::StringPrintf(
      "%04d-%02d-%02dT%02d:%02d:%02d.%03dZ",
      utc.year, utc.month, utc.day_of_month,
      utc.hour, utc.minute, utc.second,
      utc.millisecond);
}
// Returns the item's start time as the string produced by TimeToISO8601().
static std::string GetStartTime(const DownloadItem& item) {
  std::string iso_start = TimeToISO8601(item.GetStartTime());
  return iso_start;
}
// Returns the item's end time as the string produced by TimeToISO8601().
static std::string GetEndTime(const DownloadItem& item) {
  std::string iso_end = TimeToISO8601(item.GetEndTime());
  return iso_end;
}
// Returns true when the item's danger type is
// content::DOWNLOAD_DANGER_TYPE_USER_VALIDATED.
static bool GetDangerAccepted(const DownloadItem& item) {
  const bool user_validated =
      item.GetDangerType() == content::DOWNLOAD_DANGER_TYPE_USER_VALIDATED;
  return user_validated;
}
// Returns true unless the item reports that its file was externally removed.
static bool GetExists(const DownloadItem& item) {
  const bool removed_externally = item.GetFileExternallyRemoved();
  return !removed_externally;
}
// Returns the lossy display name of the item's target file path. The display
// form is used on purpose: the value is compared against strings that may
// come from the user, and LossyDisplayNames are all the user ever sees.
static base::string16 GetFilename(const DownloadItem& item) {
  base::string16 display_name = item.GetTargetFilePath().LossyDisplayName();
  return display_name;
}
// Returns GetFilename() converted from UTF-16 to UTF-8.
static std::string GetFilenameUTF8(const DownloadItem& item) {
  std::string utf8_name = UTF16ToUTF8(GetFilename(item));
  return utf8_name;
}
// Returns the spec of the item's original URL.
static std::string GetUrl(const DownloadItem& item) {
  std::string original_spec = item.GetOriginalUrl().spec();
  return original_spec;
}
// Free-function form of DownloadItem::GetState(), so that its address can be
// handed to base::Bind() as an accessor.
static DownloadItem::DownloadState GetState(const DownloadItem& item) {
  const DownloadItem::DownloadState current_state = item.GetState();
  return current_state;
}
// Free-function form of DownloadItem::GetDangerType(), so that its address
// can be handed to base::Bind() as an accessor.
static DownloadDangerType GetDangerType(const DownloadItem& item) {
  const DownloadDangerType danger = item.GetDangerType();
  return danger;
}
// Returns the number of bytes the item has received so far, as an int.
// NOTE(review): DownloadItem::GetReceivedBytes() presumably returns a wider
// integer type; if so, the conversion below truncates very large counts --
// confirm against the DownloadItem interface.
static int GetReceivedBytes(const DownloadItem& item) {
  return static_cast<int>(item.GetReceivedBytes());
}
// Returns the item's total byte count, as an int.
// NOTE(review): DownloadItem::GetTotalBytes() presumably returns a wider
// integer type; if so, the conversion below truncates very large counts --
// confirm against the DownloadItem interface.
static int GetTotalBytes(const DownloadItem& item) {
  return static_cast<int>(item.GetTotalBytes());
}
// Free-function form of DownloadItem::GetMimeType(), so that its address can
// be handed to base::Bind() as an accessor.
static std::string GetMimeType(const DownloadItem& item) {
  std::string mime_type = item.GetMimeType();
  return mime_type;
}
// Free-function form of DownloadItem::IsPaused(), so that its address can be
// handed to base::Bind() as an accessor.
static bool IsPaused(const DownloadItem& item) {
  const bool paused = item.IsPaused();
  return paused;
}
// Three-way comparison tag. FieldMatches<>() takes it as the relation to test
// (field < value, field == value, field > value); Compare<>() returns it as
// the outcome of comparing the same field of two items.
enum ComparisonType {LT, EQ, GT};
// Tests one field of |item| against |value|. |accessor| extracts the field
// from the DownloadItem; |cmptype| selects the relation that must hold
// between the extracted field (left-hand side) and |value| (right-hand side).
// Returns false for a |cmptype| outside the enum, after hitting NOTREACHED().
template<typename ValueType>
static bool FieldMatches(
    const ValueType& value,
    ComparisonType cmptype,
    const base::Callback<ValueType(const DownloadItem&)>& accessor,
    const DownloadItem& item) {
  if (cmptype == LT)
    return accessor.Run(item) < value;
  if (cmptype == EQ)
    return accessor.Run(item) == value;
  if (cmptype == GT)
    return accessor.Run(item) > value;
  NOTREACHED();
  return false;
}
// Builds a FilterCallback that runs FieldMatches<ValueType>() with |value|
// (converted to ValueType through GetAs()), |cmptype| and |accessor| bound.
// Returns a null callback when |value| cannot be converted to ValueType.
template <typename ValueType> DownloadQuery::FilterCallback BuildFilter(
    const base::Value& value, ComparisonType cmptype,
    ValueType (*accessor)(const DownloadItem&)) {
  ValueType expected;
  if (GetAs(value, &expected)) {
    return base::Bind(&FieldMatches<ValueType>, expected, cmptype,
                      base::Bind(accessor));
  }
  return DownloadQuery::FilterCallback();
}
// Returns true when |pattern| matches anywhere within the string that
// |accessor| extracts from |item| (RE2 partial match, i.e. unanchored).
static bool FindRegex(
    RE2* pattern,
    const base::Callback<std::string(const DownloadItem&)>& accessor,
    const DownloadItem& item) {
  const std::string haystack = accessor.Run(item);
  return RE2::PartialMatch(haystack, *pattern);
}
// Builds a FilterCallback that runs FindRegex() with a compiled pattern and
// |accessor| bound.
//
// |regex_value| must hold a string containing a valid RE2 pattern. Returns a
// null callback when |regex_value| is not a string or the pattern does not
// compile, so that AddFilter() rejects it.
DownloadQuery::FilterCallback BuildRegexFilter(
    const base::Value& regex_value,
    std::string (*accessor)(const DownloadItem&)) {
  std::string regex_str;
  if (!GetAs(regex_value, &regex_str)) return DownloadQuery::FilterCallback();
  scoped_ptr<RE2> pattern(new RE2(regex_str));
  if (!pattern->ok()) return DownloadQuery::FilterCallback();
  // base::Owned() transfers the compiled pattern to the callback, which
  // deletes it when the callback itself is destroyed.
  return base::Bind(&FindRegex, base::Owned(pattern.release()),
                    base::Bind(accessor));
}
// Three-way comparison of a single field: extracts the field from |left| and
// from |right| with |accessor| and reports whether the left value is greater
// than (GT), less than (LT) or equal to (EQ) the right value.
template<typename ValueType>
static ComparisonType Compare(
    const base::Callback<ValueType(const DownloadItem&)>& accessor,
    const DownloadItem& left, const DownloadItem& right) {
  const ValueType lhs = accessor.Run(left);
  const ValueType rhs = accessor.Run(right);
  if (lhs > rhs) {
    return GT;
  } else if (lhs < rhs) {
    return LT;
  }
  // Neither greater nor less: the two values are expected to be equal.
  DCHECK_EQ(lhs, rhs);
  return EQ;
}
} // anonymous namespace
// Starts with |limit_| at kuint32max. FinishSearch() only truncates results
// whose size exceeds |limit_|, so by default the result count is effectively
// unbounded.
DownloadQuery::DownloadQuery()
: limit_(kuint32max) {
}
// Nothing is released explicitly here; the members clean up after themselves.
DownloadQuery::~DownloadQuery() {
}
// AddFilter() appends a FilterCallback to filters_; most of these callbacks
// are bound calls to FieldMatches<>(). Search() walks the DownloadItems it is
// given and drops every item for which some filter returns false. A
// DownloadQuery may hold any number of FilterCallbacks, including none.
//
// A null |value| is rejected: nothing is added and false is returned.
bool DownloadQuery::AddFilter(const DownloadQuery::FilterCallback& value) {
  const bool usable = !value.is_null();
  if (usable)
    filters_.push_back(value);
  return usable;
}
// Adds a filter that keeps only the items whose state equals |state|.
void DownloadQuery::AddFilter(DownloadItem::DownloadState state) {
  const DownloadQuery::FilterCallback state_equals =
      base::Bind(&FieldMatches<DownloadItem::DownloadState>, state, EQ,
                 base::Bind(&GetState));
  AddFilter(state_equals);
}
// Adds a filter that keeps only the items whose danger type equals |danger|.
void DownloadQuery::AddFilter(DownloadDangerType danger) {
  const DownloadQuery::FilterCallback danger_equals =
      base::Bind(&FieldMatches<DownloadDangerType>, danger, EQ,
                 base::Bind(&GetDangerType));
  AddFilter(danger_equals);
}
// Adds the filter identified by |type|, parameterized by |value|.
//
// The switch only selects the callback; it is registered in one place at the
// end. Returns false when |value| has the wrong type for |type| (or holds an
// invalid regex), in which case the callback is null and nothing is added.
// FILTER_QUERY with an empty term list succeeds without adding a filter.
bool DownloadQuery::AddFilter(DownloadQuery::FilterType type,
                              const base::Value& value) {
  DownloadQuery::FilterCallback filter;
  switch (type) {
    case FILTER_BYTES_RECEIVED:
      filter = BuildFilter<int>(value, EQ, &GetReceivedBytes);
      break;
    case FILTER_DANGER_ACCEPTED:
      filter = BuildFilter<bool>(value, EQ, &GetDangerAccepted);
      break;
    case FILTER_EXISTS:
      filter = BuildFilter<bool>(value, EQ, &GetExists);
      break;
    case FILTER_FILENAME:
      filter = BuildFilter<base::string16>(value, EQ, &GetFilename);
      break;
    case FILTER_FILENAME_REGEX:
      filter = BuildRegexFilter(value, &GetFilenameUTF8);
      break;
    case FILTER_MIME:
      filter = BuildFilter<std::string>(value, EQ, &GetMimeType);
      break;
    case FILTER_PAUSED:
      filter = BuildFilter<bool>(value, EQ, &IsPaused);
      break;
    case FILTER_QUERY: {
      std::vector<base::string16> query_terms;
      if (!GetAs(value, &query_terms))
        return false;
      // No terms means there is nothing to match against; succeed as a no-op.
      if (query_terms.empty())
        return true;
      filter = base::Bind(&MatchesQuery, query_terms);
      break;
    }
    // The time filters compare the strings produced by TimeToISO8601().
    case FILTER_ENDED_AFTER:
      filter = BuildFilter<std::string>(value, GT, &GetEndTime);
      break;
    case FILTER_ENDED_BEFORE:
      filter = BuildFilter<std::string>(value, LT, &GetEndTime);
      break;
    case FILTER_END_TIME:
      filter = BuildFilter<std::string>(value, EQ, &GetEndTime);
      break;
    case FILTER_STARTED_AFTER:
      filter = BuildFilter<std::string>(value, GT, &GetStartTime);
      break;
    case FILTER_STARTED_BEFORE:
      filter = BuildFilter<std::string>(value, LT, &GetStartTime);
      break;
    case FILTER_START_TIME:
      filter = BuildFilter<std::string>(value, EQ, &GetStartTime);
      break;
    case FILTER_TOTAL_BYTES:
      filter = BuildFilter<int>(value, EQ, &GetTotalBytes);
      break;
    case FILTER_TOTAL_BYTES_GREATER:
      filter = BuildFilter<int>(value, GT, &GetTotalBytes);
      break;
    case FILTER_TOTAL_BYTES_LESS:
      filter = BuildFilter<int>(value, LT, &GetTotalBytes);
      break;
    case FILTER_URL:
      filter = BuildFilter<std::string>(value, EQ, &GetUrl);
      break;
    case FILTER_URL_REGEX:
      filter = BuildRegexFilter(value, &GetUrl);
      break;
  }
  // A |type| outside the enum leaves |filter| null, which AddFilter() rejects
  // by returning false.
  return AddFilter(filter);
}
bool DownloadQuery::Matches(const DownloadItem& item) const {
for (FilterCallbackVector::const_iterator filter = filters_.begin();
filter != filters_.end(); ++filter) {
if (!filter->Run(item))
return false;
}
return true;
}
// AddSorter() creates a Sorter and pushes it onto sorters_. A Sorter is a
// direction and a Callback to Compare<>(). After filtering, Search() makes a
// DownloadComparator functor from the sorters_ and passes the
// DownloadComparator to std::partial_sort. std::partial_sort calls the
// DownloadComparator with different pairs of DownloadItems. DownloadComparator
// iterates over the sorters until a callback returns ComparisonType LT or GT.
// DownloadComparator returns true or false depending on that ComparisonType and
// the sorter's direction in order to indicate to std::partial_sort whether the
// left item is after or before the right item. If all sorters return EQ, then
// DownloadComparator compares GetId. A DownloadQuery may have zero or more
// Sorters, but there is one DownloadComparator per call to Search().
// One sort criterion: a direction plus a callback that three-way-compares
// the same field of two DownloadItems.
struct DownloadQuery::Sorter {
  // Callback shape of a bound Compare<>() call.
  typedef base::Callback<ComparisonType(
      const DownloadItem&, const DownloadItem&)> SortType;

  Sorter(DownloadQuery::SortDirection adirection,
         const SortType& asorter)
      : direction(adirection),
        sorter(asorter) {
  }
  ~Sorter() {}

  // Creates a Sorter that orders items, in |adirection|, by the field that
  // |accessor| extracts.
  template<typename ValueType>
  static Sorter Build(DownloadQuery::SortDirection adirection,
                      ValueType (*accessor)(const DownloadItem&)) {
    const SortType compare_field =
        base::Bind(&Compare<ValueType>, base::Bind(accessor));
    return Sorter(adirection, compare_field);
  }

  DownloadQuery::SortDirection direction;
  SortType sorter;
};
// Ordering functor over DownloadItem pointers, driven by a list of Sorters.
// It keeps only a reference to that list, so the list must outlive the
// comparator. The class is deliberately left copyable: the standard sorting
// algorithms take their comparator by value.
class DownloadQuery::DownloadComparator {
 public:
  explicit DownloadComparator(const DownloadQuery::SorterVector& sorters)
      : terms_(sorters) {
  }

  // Returns true if |left| sorts before |right|.
  bool operator() (const DownloadItem* left, const DownloadItem* right);

 private:
  const DownloadQuery::SorterVector& terms_;
};
// Consults the sorters in order; the first one that tells the two items
// apart decides the result, taking that sorter's direction into account.
// When every sorter reports EQ, the item ids break the tie; the ids are
// CHECKed to differ.
bool DownloadQuery::DownloadComparator::operator() (
    const DownloadItem* left, const DownloadItem* right) {
  typedef DownloadQuery::SorterVector::const_iterator SorterIterator;
  for (SorterIterator it = terms_.begin(); it != terms_.end(); ++it) {
    const ComparisonType order = it->sorter.Run(*left, *right);
    if (order == LT)
      return it->direction == DownloadQuery::ASCENDING;
    if (order == GT)
      return it->direction == DownloadQuery::DESCENDING;
    // EQ: this sorter cannot separate the items; try the next one.
  }
  CHECK_NE(left->GetId(), right->GetId());
  return left->GetId() < right->GetId();
}
// Appends a Sorter for |type| with the given |direction| to sorters_. Earlier
// sorters take precedence; later ones only break ties left by the earlier
// ones. The cases are grouped by the type of the value being compared.
void DownloadQuery::AddSorter(DownloadQuery::SortType type,
                              DownloadQuery::SortDirection direction) {
  switch (type) {
    // Times, compared as milliseconds since the Unix epoch.
    case SORT_START_TIME:
      sorters_.push_back(Sorter::Build<int64>(direction, &GetStartTimeMsEpoch));
      break;
    case SORT_END_TIME:
      sorters_.push_back(Sorter::Build<int64>(direction, &GetEndTimeMsEpoch));
      break;

    // Byte counts.
    case SORT_BYTES_RECEIVED:
      sorters_.push_back(Sorter::Build<int>(direction, &GetReceivedBytes));
      break;
    case SORT_TOTAL_BYTES:
      sorters_.push_back(Sorter::Build<int>(direction, &GetTotalBytes));
      break;

    // Strings.
    case SORT_URL:
      sorters_.push_back(Sorter::Build<std::string>(direction, &GetUrl));
      break;
    case SORT_MIME:
      sorters_.push_back(Sorter::Build<std::string>(direction, &GetMimeType));
      break;
    case SORT_FILENAME:
      sorters_.push_back(
          Sorter::Build<base::string16>(direction, &GetFilename));
      break;

    // Booleans.
    case SORT_DANGER_ACCEPTED:
      sorters_.push_back(Sorter::Build<bool>(direction, &GetDangerAccepted));
      break;
    case SORT_EXISTS:
      sorters_.push_back(Sorter::Build<bool>(direction, &GetExists));
      break;
    case SORT_PAUSED:
      sorters_.push_back(Sorter::Build<bool>(direction, &IsPaused));
      break;

    // Enumerations, compared by enumerator value.
    case SORT_DANGER:
      sorters_.push_back(Sorter::Build<DownloadDangerType>(
          direction, &GetDangerType));
      break;
    case SORT_STATE:
      sorters_.push_back(Sorter::Build<DownloadItem::DownloadState>(
          direction, &GetState));
      break;
  }
}
// Orders and truncates the already-filtered |results| in place. When sorters
// exist, only the leading elements that will survive truncation are sorted;
// the rest are left in unspecified order and then dropped. Without sorters
// the existing order is kept and the vector is only truncated to |limit_|.
void DownloadQuery::FinishSearch(DownloadQuery::DownloadVector* results) const {
  // Number of elements that remain after applying |limit_|.
  const size_t keep = std::min(limit_, results->size());
  if (!sorters_.empty()) {
    std::partial_sort(results->begin(),
                      results->begin() + keep,
                      results->end(),
                      DownloadComparator(sorters_));
  }
  if (keep < results->size())
    results->resize(keep);
}