// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "webkit/browser/quota/usage_tracker.h"

#include <algorithm>
#include <deque>
#include <set>
#include <string>
#include <vector>

#include "base/bind.h"
#include "base/message_loop/message_loop_proxy.h"
#include "base/stl_util.h"
#include "net/base/net_util.h"

namespace quota {

namespace {

// Local shorthands for types nested in ClientUsageTracker, so the helpers in
// this anonymous namespace can name them without the class qualifier.
typedef ClientUsageTracker::OriginUsageAccumulator OriginUsageAccumulator;
typedef ClientUsageTracker::OriginSetByHost OriginSetByHost;

// Adapter that forwards a usage result to |accumulator| together with the
// |origin| it belongs to.  It is bound with |accumulator| and |origin| up
// front so that the remaining signature takes only the |usage| value.
void DidGetOriginUsage(const OriginUsageAccumulator& accumulator,
                       const GURL& origin,
                       int64 usage) {
  accumulator.Run(origin, usage);
}

void DidGetHostUsage(const UsageCallback& callback,
                     int64 limited_usage,
                     int64 unlimited_usage) {
  DCHECK_GE(limited_usage, 0);
  DCHECK_GE(unlimited_usage, 0);
  callback.Run(limited_usage + unlimited_usage);
}

// Host-usage callback that intentionally discards its result.  Used when a
// host usage query is issued only for its side effects.
void NoopHostUsageCallback(int64 usage) {}

// Removes |origin| from the origin set stored under |host| in
// |origins_by_host|.  If that leaves the set empty, the whole |host| entry is
// dropped as well.
// Returns true if |origin| was present and has been removed, false if either
// |host| or |origin| was not found (in which case nothing is modified).
bool EraseOriginFromOriginSet(OriginSetByHost* origins_by_host,
                              const std::string& host,
                              const GURL& origin) {
  OriginSetByHost::iterator host_entry = origins_by_host->find(host);
  if (host_entry == origins_by_host->end())
    return false;

  const bool erased = host_entry->second.erase(origin) != 0;
  if (!erased)
    return false;

  // Do not keep empty per-host sets around.
  if (host_entry->second.empty())
    origins_by_host->erase(host_entry);
  return true;
}

bool OriginSetContainsOrigin(const OriginSetByHost& origins,
                             const std::string& host,
                             const GURL& origin) {
  OriginSetByHost::const_iterator itr = origins.find(host);
  return itr != origins.end() && ContainsKey(itr->second, origin);
}

// Adapts a global-usage result (total, unlimited) to a callback that only
// wants the limited portion: runs |callback| with the total minus the
// unlimited usage.
void DidGetGlobalUsageForLimitedGlobalUsage(const UsageCallback& callback,
                                            int64 total_global_usage,
                                            int64 global_unlimited_usage) {
  callback.Run(total_global_usage - global_unlimited_usage);
}

}  // namespace

// UsageTracker ----------------------------------------------------------

// Builds one ClientUsageTracker for every client in |clients| that supports
// storage |type|, keyed by the client's id.  Clients that do not support
// |type| are skipped.  The trackers are heap-allocated and owned by
// |client_tracker_map_|.
UsageTracker::UsageTracker(const QuotaClientList& clients,
                           StorageType type,
                           SpecialStoragePolicy* special_storage_policy)
    : type_(type),
      weak_factory_(this) {
  for (QuotaClientList::const_iterator client_itr = clients.begin();
       client_itr != clients.end(); ++client_itr) {
    QuotaClient* client = *client_itr;
    if (!client->DoesSupport(type))
      continue;

    client_tracker_map_[client->id()] =
        new ClientUsageTracker(this, client, type, special_storage_policy);
  }
}

// Releases the ClientUsageTracker instances that the constructor allocated
// with new and stored as values in |client_tracker_map_|.
UsageTracker::~UsageTracker() {
  STLDeleteValues(&client_tracker_map_);
}

// Looks up the tracker registered for |client_id|.  Returns NULL when no
// tracker exists for that id (the constructor only creates trackers for
// clients that support this tracker's storage type).
ClientUsageTracker* UsageTracker::GetClientTracker(QuotaClient::ID client_id) {
  ClientTrackerMap::iterator entry = client_tracker_map_.find(client_id);
  if (entry == client_tracker_map_.end())
    return NULL;
  return entry->second;
}

void UsageTracker::GetGlobalLimitedUsage(const UsageCallback& callback) {
  if (global_usage_callbacks_.HasCallbacks()) {
    global_usage_callbacks_.Add(base::Bind(
        &DidGetGlobalUsageForLimitedGlobalUsage, callback));
    return;
  }

  if (!global_limited_usage_callbacks_.Add(callback))
    return;

  AccumulateInfo* info = new AccumulateInfo;
  // Calling GetGlobalLimitedUsage(accumulator) may synchronously
  // return if the usage is cached, which may in turn dispatch
  // the completion callback before we finish looping over
  // all clients (because info->pending_clients may reach 0
  // during the loop).
  // To avoid this, we add one more pending client as a sentinel
  // and fire the sentinel callback at the end.
  info->pending_clients = client_tracker_map_.size() + 1;
  UsageCallback accumulator = base::Bind(
      &UsageTracker::AccumulateClientGlobalLimitedUsage,
      weak_factory_.GetWeakPtr(), base::Owned(info));

  for (ClientTrackerMap::iterator iter = client_tracker_map_.begin();
       iter != client_tracker_map_.end();
       ++iter)
    iter->second->GetGlobalLimitedUsage(accumulator);

  // Fire the sentinel as we've now called GetGlobalUsage for all clients.
  accumulator.Run(0);
}

void UsageTracker::GetGlobalUsage(const GlobalUsageCallback& callback) {
  if (!global_usage_callbacks_.Add(callback))
    return;

  AccumulateInfo* info = new AccumulateInfo;
  // Calling GetGlobalUsage(accumulator) may synchronously
  // return if the usage is cached, which may in turn dispatch
  // the completion callback before we finish looping over
  // all clients (because info->pending_clients may reach 0
  // during the loop).
  // To avoid this, we add one more pending client as a sentinel
  // and fire the sentinel callback at the end.
  info->pending_clients = client_tracker_map_.size() + 1;
  GlobalUsageCallback accumulator = base::Bind(
      &UsageTracker::AccumulateClientGlobalUsage, weak_factory_.GetWeakPtr(),
      base::Owned(info));

  for (ClientTrackerMap::iterator iter = client_tracker_map_.begin();
       iter != client_tracker_map_.end();
       ++iter)
    iter->second->GetGlobalUsage(accumulator);

  // Fire the sentinel as we've now called GetGlobalUsage for all clients.
  accumulator.Run(0, 0);
}

void UsageTracker::GetHostUsage(const std::string& host,
                                const UsageCallback& callback) {
  if (!host_usage_callbacks_.Add(host, callback))
    return;

  AccumulateInfo* info = new AccumulateInfo;
  // Calling GetHostUsage(accumulator) may synchronously
  // return if the usage is cached, which may in turn dispatch
  // the completion callback before we finish looping over
  // all clients (because info->pending_clients may reach 0
  // during the loop).
  // To avoid this, we add one more pending client as a sentinel
  // and fire the sentinel callback at the end.
  info->pending_clients = client_tracker_map_.size() + 1;
  UsageCallback accumulator = base::Bind(
      &UsageTracker::AccumulateClientHostUsage, weak_factory_.GetWeakPtr(),
      base::Owned(info), host);

  for (ClientTrackerMap::iterator iter = client_tracker_map_.begin();
       iter != client_tracker_map_.end();
       ++iter)
    iter->second->GetHostUsage(host, accumulator);

  // Fire the sentinel as we've now called GetHostUsage for all clients.
  accumulator.Run(0);
}

// Forwards a usage change of |delta| for |origin| to the tracker of the
// client identified by |client_id|.  A tracker must exist for |client_id|.
void UsageTracker::UpdateUsageCache(
    QuotaClient::ID client_id, const GURL& origin, int64 delta) {
  ClientUsageTracker* tracker = GetClientTracker(client_id);
  DCHECK(tracker);
  tracker->UpdateUsageCache(origin, delta);
}

void UsageTracker::GetCachedHostsUsage(
    std::map<std::string, int64>* host_usage) const {
  DCHECK(host_usage);
  host_usage->clear();
  for (ClientTrackerMap::const_iterator iter = client_tracker_map_.begin();
       iter != client_tracker_map_.end(); ++iter) {
    iter->second->GetCachedHostsUsage(host_usage);
  }
}

void UsageTracker::GetCachedOrigins(std::set<GURL>* origins) const {
  DCHECK(origins);
  origins->clear();
  for (ClientTrackerMap::const_iterator iter = client_tracker_map_.begin();
       iter != client_tracker_map_.end(); ++iter) {
    iter->second->GetCachedOrigins(origins);
  }
}

// Turns usage caching for |origin| on or off in the tracker of the client
// identified by |client_id|.  A tracker must exist for |client_id|.
void UsageTracker::SetUsageCacheEnabled(QuotaClient::ID client_id,
                                        const GURL& origin,
                                        bool enabled) {
  ClientUsageTracker* tracker = GetClientTracker(client_id);
  DCHECK(tracker);
  tracker->SetUsageCacheEnabled(origin, enabled);
}

// Adds one client's |limited_usage| to the running total in |info|.  When
// the last outstanding job (including the sentinel) has reported, the queued
// global-limited-usage callbacks are run with the total.
void UsageTracker::AccumulateClientGlobalLimitedUsage(AccumulateInfo* info,
                                                      int64 limited_usage) {
  info->usage += limited_usage;

  --info->pending_clients;
  if (info->pending_clients != 0)
    return;  // Still waiting for other clients.

  // Every client has reported; dispatch the pending callbacks.
  global_limited_usage_callbacks_.Run(MakeTuple(info->usage));
}

// Adds one client's |usage| and |unlimited_usage| to the running totals in
// |info|.  When the last outstanding job (including the sentinel) has
// reported, the totals are sanitized and the queued global-usage callbacks
// are run with them.
void UsageTracker::AccumulateClientGlobalUsage(AccumulateInfo* info,
                                               int64 usage,
                                               int64 unlimited_usage) {
  info->usage += usage;
  info->unlimited_usage += unlimited_usage;

  --info->pending_clients;
  if (info->pending_clients != 0)
    return;  // Still waiting for other clients.

  // Defend against confusing inputs from clients: the total is never
  // reported as negative.
  if (info->usage < 0)
    info->usage = 0;

  // TODO(michaeln): The unlimited number is not trustworthy, it
  // can get out of whack when apps are installed or uninstalled.
  // Keep it within [0, total usage].
  if (info->unlimited_usage < 0)
    info->unlimited_usage = 0;
  else if (info->unlimited_usage > info->usage)
    info->unlimited_usage = info->usage;

  // Every client has reported; dispatch the pending callbacks.
  global_usage_callbacks_.Run(MakeTuple(info->usage, info->unlimited_usage));
}

// Adds one client's |usage| for |host| to the running total in |info|.  When
// the last outstanding job (including the sentinel) has reported, the total
// is clamped to be non-negative and the callbacks queued for |host| are run.
void UsageTracker::AccumulateClientHostUsage(AccumulateInfo* info,
                                             const std::string& host,
                                             int64 usage) {
  info->usage += usage;

  --info->pending_clients;
  if (info->pending_clients != 0)
    return;  // Still waiting for other clients.

  // Defend against confusing inputs from clients.
  if (info->usage < 0)
    info->usage = 0;

  // Every client has reported; dispatch the pending callbacks.
  host_usage_callbacks_.Run(host, MakeTuple(info->usage));
}

// ClientUsageTracker ----------------------------------------------------

// Creates a tracker for |client|'s usage of storage |type| on behalf of
// |tracker|.  Both |tracker| and |client| must be non-NULL.  When a
// |special_storage_policy| is supplied, this object registers itself as an
// observer of it (and unregisters in the destructor).
ClientUsageTracker::ClientUsageTracker(
    UsageTracker* tracker, QuotaClient* client, StorageType type,
    SpecialStoragePolicy* special_storage_policy)
    : tracker_(tracker),
      client_(client),
      type_(type),
      global_limited_usage_(0),
      global_unlimited_usage_(0),
      global_usage_retrieved_(false),
      special_storage_policy_(special_storage_policy) {
  DCHECK(tracker_);
  DCHECK(client_);

  if (!special_storage_policy_.get())
    return;
  special_storage_policy_->AddObserver(this);
}

// Unregisters this object from the special storage policy, if one was given
// at construction time.
ClientUsageTracker::~ClientUsageTracker() {
  if (!special_storage_policy_.get())
    return;
  special_storage_policy_->RemoveObserver(this);
}

// Reports, via |callback|, this client's usage by limited origins.
//
// - If global usage has not been retrieved yet, a full GetGlobalUsage() query
//   is started and its result is reduced to the limited portion.
// - If no limited origin has caching disabled, the cached
//   |global_limited_usage_| is reported synchronously.
// - Otherwise the usage of every non-cached limited origin is fetched from
//   the client and added on top of the cached |global_limited_usage_|.
void ClientUsageTracker::GetGlobalLimitedUsage(const UsageCallback& callback) {
  if (!global_usage_retrieved_) {
    GetGlobalUsage(base::Bind(&DidGetGlobalUsageForLimitedGlobalUsage,
                              callback));
    return;
  }

  if (non_cached_limited_origins_by_host_.empty()) {
    callback.Run(global_limited_usage_);
    return;
  }

  // One GetOriginUsage() request is issued per non-cached *origin* below, so
  // the number of pending jobs has to be the total number of origins across
  // all hosts -- not the number of hosts.  Counting hosts would let the
  // counter reach zero too early whenever a host has more than one non-cached
  // origin, running |callback| with an incomplete sum and dropping the
  // remaining results.
  size_t non_cached_origin_count = 0;
  for (OriginSetByHost::const_iterator host_itr =
           non_cached_limited_origins_by_host_.begin();
       host_itr != non_cached_limited_origins_by_host_.end(); ++host_itr) {
    non_cached_origin_count += host_itr->second.size();
  }

  AccumulateInfo* info = new AccumulateInfo;
  // The extra job is the sentinel fired at the end of this function.  It
  // keeps the counter from reaching zero while requests are still being
  // issued, in case the client replies synchronously.
  info->pending_jobs = non_cached_origin_count + 1;
  UsageCallback accumulator = base::Bind(
      &ClientUsageTracker::AccumulateLimitedOriginUsage, AsWeakPtr(),
      base::Owned(info), callback);

  for (OriginSetByHost::iterator host_itr =
           non_cached_limited_origins_by_host_.begin();
       host_itr != non_cached_limited_origins_by_host_.end(); ++host_itr) {
    for (std::set<GURL>::iterator origin_itr = host_itr->second.begin();
         origin_itr != host_itr->second.end(); ++origin_itr)
      client_->GetOriginUsage(*origin_itr, type_, accumulator);
  }

  // Fire the sentinel, contributing the cached portion of the limited usage.
  accumulator.Run(global_limited_usage_);
}

// Reports, via |callback|, this client's total usage and unlimited usage.
// The cached totals are returned synchronously when global usage has already
// been retrieved and no origin has caching disabled; otherwise the client is
// asked for its origins and usage is gathered from there.
void ClientUsageTracker::GetGlobalUsage(const GlobalUsageCallback& callback) {
  const bool can_answer_from_cache =
      global_usage_retrieved_ &&
      non_cached_limited_origins_by_host_.empty() &&
      non_cached_unlimited_origins_by_host_.empty();

  if (!can_answer_from_cache) {
    client_->GetOriginsForType(type_, base::Bind(
        &ClientUsageTracker::DidGetOriginsForGlobalUsage, AsWeakPtr(),
        callback));
    return;
  }

  callback.Run(global_limited_usage_ + global_unlimited_usage_,
               global_unlimited_usage_);
}

// Reports, via |callback|, this client's usage for |host|.  The cached value
// is returned synchronously when |host| is cached and none of its origins has
// caching disabled; otherwise the client is asked for the host's origins.
void ClientUsageTracker::GetHostUsage(
    const std::string& host, const UsageCallback& callback) {
  const bool can_answer_from_cache =
      ContainsKey(cached_hosts_, host) &&
      !ContainsKey(non_cached_limited_origins_by_host_, host) &&
      !ContainsKey(non_cached_unlimited_origins_by_host_, host);
  if (can_answer_from_cache) {
    // TODO(kinuko): Drop host_usage_map_ cache periodically.
    callback.Run(GetCachedHostUsage(host));
    return;
  }

  // Queue |callback|; Add() returning false means no new client query should
  // be started here (presumably one is already pending for |host| -- confirm
  // against the callback queue's contract).
  const bool should_query_client = host_usage_accumulators_.Add(
      host, base::Bind(&DidGetHostUsage, callback));
  if (should_query_client) {
    client_->GetOriginsForHost(type_, host, base::Bind(
        &ClientUsageTracker::DidGetOriginsForHostUsage, AsWeakPtr(), host));
  }
}

void ClientUsageTracker::UpdateUsageCache(
    const GURL& origin, int64 delta) {
  std::string host = net::GetHostOrSpecFromURL(origin);
  if (cached_hosts_.find(host) != cached_hosts_.end()) {
    if (!IsUsageCacheEnabledForOrigin(origin))
      return;

    cached_usage_by_host_[host][origin] += delta;
    if (IsStorageUnlimited(origin))
      global_unlimited_usage_ += delta;
    else
      global_limited_usage_ += delta;
    DCHECK_GE(cached_usage_by_host_[host][origin], 0);
    DCHECK_GE(global_limited_usage_, 0);
    return;
  }

  // We don't know about this host yet, so populate our cache for it.
  GetHostUsage(host, base::Bind(&NoopHostUsageCallback));
}

void ClientUsageTracker::GetCachedHostsUsage(
    std::map<std::string, int64>* host_usage) const {
  DCHECK(host_usage);
  for (HostUsageMap::const_iterator host_iter = cached_usage_by_host_.begin();
       host_iter != cached_usage_by_host_.end(); host_iter++) {
    const std::string& host = host_iter->first;
    (*host_usage)[host] += GetCachedHostUsage(host);
  }
}

void ClientUsageTracker::GetCachedOrigins(std::set<GURL>* origins) const {
  DCHECK(origins);
  for (HostUsageMap::const_iterator host_iter = cached_usage_by_host_.begin();
       host_iter != cached_usage_by_host_.end(); host_iter++) {
    const UsageMap& origin_map = host_iter->second;
    for (UsageMap::const_iterator origin_iter = origin_map.begin();
         origin_iter != origin_map.end(); origin_iter++) {
      origins->insert(origin_iter->first);
    }
  }
}

void ClientUsageTracker::SetUsageCacheEnabled(const GURL& origin,
                                              bool enabled) {
  std::string host = net::GetHostOrSpecFromURL(origin);
  if (!enabled) {
    // Erase |origin| from cache and subtract its usage.
    HostUsageMap::iterator found_host = cached_usage_by_host_.find(host);
    if (found_host != cached_usage_by_host_.end()) {
      UsageMap& cached_usage_for_host = found_host->second;

      UsageMap::iterator found = cached_usage_for_host.find(origin);
      if (found != cached_usage_for_host.end()) {
        int64 usage = found->second;
        UpdateUsageCache(origin, -usage);
        cached_usage_for_host.erase(found);
        if (cached_usage_for_host.empty()) {
          cached_usage_by_host_.erase(found_host);
          cached_hosts_.erase(host);
        }
      }
    }

    if (IsStorageUnlimited(origin))
      non_cached_unlimited_origins_by_host_[host].insert(origin);
    else
      non_cached_limited_origins_by_host_[host].insert(origin);
  } else {
    // Erase |origin| from |non_cached_origins_| and invalidate the usage cache
    // for the host.
    if (EraseOriginFromOriginSet(&non_cached_limited_origins_by_host_,
                                 host, origin) ||
        EraseOriginFromOriginSet(&non_cached_unlimited_origins_by_host_,
                                 host, origin)) {
      cached_hosts_.erase(host);
      global_usage_retrieved_ = false;
    }
  }
}

// Adds |usage| to the limited-usage total in |info|.  Once the last pending
// job has reported, |callback| is run with the accumulated total.
void ClientUsageTracker::AccumulateLimitedOriginUsage(
    AccumulateInfo* info,
    const UsageCallback& callback,
    int64 usage) {
  info->limited_usage += usage;

  --info->pending_jobs;
  if (info->pending_jobs != 0)
    return;  // More results are still outstanding.

  callback.Run(info->limited_usage);
}

void ClientUsageTracker::DidGetOriginsForGlobalUsage(
    const GlobalUsageCallback& callback,
    const std::set<GURL>& origins) {
  OriginSetByHost origins_by_host;
  for (std::set<GURL>::const_iterator itr = origins.begin();
       itr != origins.end(); ++itr)
    origins_by_host[net::GetHostOrSpecFromURL(*itr)].insert(*itr);

  AccumulateInfo* info = new AccumulateInfo;
  // Getting host usage may synchronously return the result if the usage is
  // cached, which may in turn dispatch the completion callback before we finish
  // looping over all hosts (because info->pending_jobs may reach 0 during the
  // loop).  To avoid this, we add one more pending host as a sentinel and
  // fire the sentinel callback at the end.
  info->pending_jobs = origins_by_host.size() + 1;
  HostUsageAccumulator accumulator =
      base::Bind(&ClientUsageTracker::AccumulateHostUsage, AsWeakPtr(),
                 base::Owned(info), callback);

  for (OriginSetByHost::iterator itr = origins_by_host.begin();
       itr != origins_by_host.end(); ++itr) {
    if (host_usage_accumulators_.Add(itr->first, accumulator))
      GetUsageForOrigins(itr->first, itr->second);
  }

  // Fire the sentinel as we've now called GetUsageForOrigins for all clients.
  accumulator.Run(0, 0);
}

// Adds one host's |limited_usage| and |unlimited_usage| to the totals in
// |info|.  Once the last pending job has reported, global usage is marked as
// retrieved and |callback| is run with (total usage, unlimited usage).
void ClientUsageTracker::AccumulateHostUsage(
    AccumulateInfo* info,
    const GlobalUsageCallback& callback,
    int64 limited_usage,
    int64 unlimited_usage) {
  info->limited_usage += limited_usage;
  info->unlimited_usage += unlimited_usage;

  --info->pending_jobs;
  if (info->pending_jobs != 0)
    return;  // More hosts are still outstanding.

  DCHECK_GE(info->limited_usage, 0);
  DCHECK_GE(info->unlimited_usage, 0);

  global_usage_retrieved_ = true;
  const int64 total_usage = info->limited_usage + info->unlimited_usage;
  callback.Run(total_usage, info->unlimited_usage);
}

// Continuation of GetHostUsage(): receives the |origins| the client reported
// for |host| and starts gathering their usage.
void ClientUsageTracker::DidGetOriginsForHostUsage(
    const std::string& host,
    const std::set<GURL>& origins) {
  GetUsageForOrigins(host, origins);
}

void ClientUsageTracker::GetUsageForOrigins(
    const std::string& host,
    const std::set<GURL>& origins) {
  AccumulateInfo* info = new AccumulateInfo;
  // Getting origin usage may synchronously return the result if the usage is
  // cached, which may in turn dispatch the completion callback before we finish
  // looping over all origins (because info->pending_jobs may reach 0 during the
  // loop).  To avoid this, we add one more pending origin as a sentinel and
  // fire the sentinel callback at the end.
  info->pending_jobs = origins.size() + 1;
  OriginUsageAccumulator accumulator =
      base::Bind(&ClientUsageTracker::AccumulateOriginUsage, AsWeakPtr(),
                 base::Owned(info), host);

  for (std::set<GURL>::const_iterator itr = origins.begin();
       itr != origins.end(); ++itr) {
    DCHECK_EQ(host, net::GetHostOrSpecFromURL(*itr));

    int64 origin_usage = 0;
    if (GetCachedOriginUsage(*itr, &origin_usage)) {
      accumulator.Run(*itr, origin_usage);
    } else {
      client_->GetOriginUsage(*itr, type_, base::Bind(
          &DidGetOriginUsage, accumulator, *itr));
    }
  }

  // Fire the sentinel as we've now called GetOriginUsage for all clients.
  accumulator.Run(GURL(), 0);
}

// Records the |usage| reported for |origin| (belonging to |host|) in |info|
// and, if caching is enabled for the origin, in the usage cache.  Negative
// usage is treated as zero.  A call with an empty |origin| only counts down
// the pending jobs (this is how GetUsageForOrigins() fires its sentinel).
// Once the last pending job has reported, |host| is marked as cached and the
// accumulators queued for |host| are run with (limited, unlimited) usage.
void ClientUsageTracker::AccumulateOriginUsage(AccumulateInfo* info,
                                               const std::string& host,
                                               const GURL& origin,
                                               int64 usage) {
  if (!origin.is_empty()) {
    const int64 sanitized_usage = usage < 0 ? 0 : usage;

    if (IsStorageUnlimited(origin))
      info->unlimited_usage += sanitized_usage;
    else
      info->limited_usage += sanitized_usage;

    if (IsUsageCacheEnabledForOrigin(origin))
      AddCachedOrigin(origin, sanitized_usage);
  }

  --info->pending_jobs;
  if (info->pending_jobs != 0)
    return;  // More origins are still outstanding.

  AddCachedHost(host);
  host_usage_accumulators_.Run(
      host, MakeTuple(info->limited_usage, info->unlimited_usage));
}

void ClientUsageTracker::AddCachedOrigin(
    const GURL& origin, int64 new_usage) {
  DCHECK(IsUsageCacheEnabledForOrigin(origin));

  std::string host = net::GetHostOrSpecFromURL(origin);
  int64* usage = &cached_usage_by_host_[host][origin];
  int64 delta = new_usage - *usage;
  *usage = new_usage;
  if (delta) {
    if (IsStorageUnlimited(origin))
      global_unlimited_usage_ += delta;
    else
      global_limited_usage_ += delta;
  }
  DCHECK_GE(*usage, 0);
  DCHECK_GE(global_limited_usage_, 0);
}

// Marks |host| as cached by inserting it into |cached_hosts_|.
void ClientUsageTracker::AddCachedHost(const std::string& host) {
  cached_hosts_.insert(host);
}

// Returns the sum of the cached usage of all origins stored under |host|, or
// 0 if there is no cache entry for |host|.
int64 ClientUsageTracker::GetCachedHostUsage(const std::string& host) const {
  int64 total_usage = 0;

  HostUsageMap::const_iterator host_entry = cached_usage_by_host_.find(host);
  if (host_entry != cached_usage_by_host_.end()) {
    const UsageMap& usage_by_origin = host_entry->second;
    UsageMap::const_iterator origin_entry = usage_by_origin.begin();
    for (; origin_entry != usage_by_origin.end(); ++origin_entry)
      total_usage += origin_entry->second;
  }

  return total_usage;
}

bool ClientUsageTracker::GetCachedOriginUsage(
    const GURL& origin,
    int64* usage) const {
  std::string host = net::GetHostOrSpecFromURL(origin);
  HostUsageMap::const_iterator found_host = cached_usage_by_host_.find(host);
  if (found_host == cached_usage_by_host_.end())
    return false;

  UsageMap::const_iterator found = found_host->second.find(origin);
  if (found == found_host->second.end())
    return false;

  DCHECK(IsUsageCacheEnabledForOrigin(origin));
  *usage = found->second;
  return true;
}

// Returns true unless |origin| is listed in either of the non-cached origin
// sets (limited or unlimited).
bool ClientUsageTracker::IsUsageCacheEnabledForOrigin(
    const GURL& origin) const {
  const std::string host = net::GetHostOrSpecFromURL(origin);

  const bool caching_disabled =
      OriginSetContainsOrigin(non_cached_limited_origins_by_host_,
                              host, origin) ||
      OriginSetContainsOrigin(non_cached_unlimited_origins_by_host_,
                              host, origin);
  return !caching_disabled;
}

void ClientUsageTracker::OnGranted(const GURL& origin,
                                   int change_flags) {
  DCHECK(CalledOnValidThread());
  if (change_flags & SpecialStoragePolicy::STORAGE_UNLIMITED) {
    int64 usage = 0;
    if (GetCachedOriginUsage(origin, &usage)) {
      global_unlimited_usage_ += usage;
      global_limited_usage_ -= usage;
    }

    std::string host = net::GetHostOrSpecFromURL(origin);
    if (EraseOriginFromOriginSet(&non_cached_limited_origins_by_host_,
                                 host, origin))
      non_cached_unlimited_origins_by_host_[host].insert(origin);
  }
}

void ClientUsageTracker::OnRevoked(const GURL& origin,
                                   int change_flags) {
  DCHECK(CalledOnValidThread());
  if (change_flags & SpecialStoragePolicy::STORAGE_UNLIMITED) {
    int64 usage = 0;
    if (GetCachedOriginUsage(origin, &usage)) {
      global_unlimited_usage_ -= usage;
      global_limited_usage_ += usage;
    }

    std::string host = net::GetHostOrSpecFromURL(origin);
    if (EraseOriginFromOriginSet(&non_cached_unlimited_origins_by_host_,
                                 host, origin))
      non_cached_limited_origins_by_host_[host].insert(origin);
  }
}

void ClientUsageTracker::OnCleared() {
  DCHECK(CalledOnValidThread());
  global_limited_usage_ += global_unlimited_usage_;
  global_unlimited_usage_ = 0;

  for (OriginSetByHost::const_iterator host_itr =
           non_cached_unlimited_origins_by_host_.begin();
       host_itr != non_cached_unlimited_origins_by_host_.end();
       ++host_itr) {
    for (std::set<GURL>::const_iterator origin_itr = host_itr->second.begin();
         origin_itr != host_itr->second.end();
         ++origin_itr)
      non_cached_limited_origins_by_host_[host_itr->first].insert(*origin_itr);
  }
  non_cached_unlimited_origins_by_host_.clear();
}

// Returns whether |origin| has unlimited storage according to the special
// storage policy.  Always false for syncable storage, and false when no
// policy was supplied.
bool ClientUsageTracker::IsStorageUnlimited(const GURL& origin) const {
  if (type_ == kStorageTypeSyncable)
    return false;
  if (!special_storage_policy_.get())
    return false;
  return special_storage_policy_->IsStorageUnlimited(origin);
}

}  // namespace quota