// Copyright (c) 2009 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/websockets/websocket_throttle.h"

#include <string>

#include "base/message_loop.h"
#include "base/ref_counted.h"
#include "base/singleton.h"
#include "base/string_util.h"
#include "net/base/io_buffer.h"
#include "net/base/sys_addrinfo.h"
#include "net/socket_stream/socket_stream.h"

namespace net {

// Builds the map key identifying the address held in |addrinfo|.
// The key is the decimal address family, a ':' and the hex encoding of the
// address bytes.  For AF_INET and AF_INET6 only the IP address field of the
// sockaddr is encoded; for any other family all |ai_addrlen| bytes of
// |ai_addr| are encoded.
static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) {
  // Default: hash the whole sockaddr blob for families we do not know.
  const void* bytes = addrinfo->ai_addr;
  size_t length = addrinfo->ai_addrlen;

  if (addrinfo->ai_family == AF_INET) {
    const struct sockaddr_in* const addr =
        reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr);
    bytes = &addr->sin_addr;
    length = 4;
  } else if (addrinfo->ai_family == AF_INET6) {
    const struct sockaddr_in6* const addr6 =
        reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr);
    bytes = &addr6->sin6_addr;
    length = sizeof(addr6->sin6_addr);
  }

  const std::string hex = HexEncode(bytes, length);
  return StringPrintf("%d:%s", addrinfo->ai_family, hex.c_str());
}

// Per-SocketStream bookkeeping for the WebSocket opening handshake.
// An instance is owned by its SocketStream as UserData keyed by
// WebSocketState::kKeyName, and is alive from the moment the connection
// starts until the handshake is considered finished.
// "Finished" here only means that the end of the response headers has been
// found in the data read so far; the handshake itself is not validated.
class WebSocketThrottle::WebSocketState : public SocketStream::UserData {
 public:
  // |addrs| is kept by reference, not copied, so it must outlive this object.
  explicit WebSocketState(const AddressList& addrs)
      : address_list_(addrs),
        callback_(NULL),
        waiting_(false),
        handshake_finished_(false),
        buffer_(NULL) {
  }
  ~WebSocketState() {}

  // Called when the connection is about to be opened.
  // Returns OK if this state is not waiting for another connection.
  // Otherwise remembers |callback| (run later by Wakeup()) and returns
  // ERR_IO_PENDING.
  int OnStartOpenConnection(CompletionCallback* callback) {
    DCHECK(!callback_);
    if (!waiting_)
      return OK;
    callback_ = callback;
    return ERR_IO_PENDING;
  }

  // Feeds |len| bytes of received |data| into the end-of-headers detection.
  // Sets the handshake-finished flag once the end of the headers is found,
  // either in |data| alone (first call) or in the data accumulated across
  // calls.  Always returns OK; |callback| is not used.
  int OnRead(const char* data, int len, CompletionCallback* callback) {
    DCHECK(!waiting_);
    DCHECK(!callback_);
    DCHECK(!handshake_finished_);
    // Growth step of the accumulation buffer.
    // NOTE(review): 8129 looks like a transposition of 8192; value kept.
    static const int kBufferSize = 8129;

    // Nothing to scan or store.  This also keeps a bogus negative |len|
    // from ever reaching memcpy() below.
    if (len <= 0)
      return OK;

    if (!buffer_) {
      // Fast path.
      int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0);
      if (eoh > 0) {
        handshake_finished_ = true;
        return OK;
      }
      buffer_ = new GrowableIOBuffer();
      buffer_->SetCapacity(kBufferSize);
    }

    // Make sure the whole chunk fits before copying.  A single kBufferSize
    // step is not enough when |len| exceeds it, neither for a fresh buffer
    // nor for an existing one, so grow by as many steps as needed.
    const int shortfall = len - buffer_->RemainingCapacity();
    if (shortfall > 0) {
      const int extra_steps = (shortfall + kBufferSize - 1) / kBufferSize;
      buffer_->SetCapacity(buffer_->capacity() + extra_steps * kBufferSize);
    }

    // Append at the current position and advance the offset past the new
    // bytes (assumes data() points at the current offset -- TODO confirm
    // against GrowableIOBuffer).
    memcpy(buffer_->data(), data, len);
    buffer_->set_offset(buffer_->offset() + len);

    // Re-scan everything accumulated so far, since the header terminator
    // may straddle two reads.
    int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(),
                                           buffer_->offset(), 0);
    handshake_finished_ = (eoh > 0);
    return OK;
  }

  const AddressList& address_list() const { return address_list_; }
  void SetWaiting() { waiting_ = true; }
  bool IsWaiting() const { return waiting_; }
  bool HandshakeFinished() const { return handshake_finished_; }

  // Clears the waiting flag and posts a task on the current IO message loop
  // that runs the stored callback with OK.  The callback pointer is reset
  // before the task is posted.
  void Wakeup() {
    waiting_ = false;
    // We wrap |callback_| to keep this alive while this is released.
    scoped_refptr<CompletionCallbackRunner> runner =
        new CompletionCallbackRunner(callback_);
    callback_ = NULL;
    MessageLoopForIO::current()->PostTask(
        FROM_HERE,
        NewRunnableMethod(runner.get(),
                          &CompletionCallbackRunner::Run));
  }

  // UserData key under which instances are attached to a SocketStream.
  static const char* kKeyName;

 private:
  // Ref-counted holder for a CompletionCallback so that it can be run from
  // a posted task.  It does not own the callback.
  class CompletionCallbackRunner
      : public base::RefCountedThreadSafe<CompletionCallbackRunner> {
   public:
    explicit CompletionCallbackRunner(CompletionCallback* callback)
        : callback_(callback) {
      DCHECK(callback_);
    }
    // Runs the wrapped callback with OK.
    void Run() {
      callback_->Run(OK);
    }
   private:
    friend class base::RefCountedThreadSafe<CompletionCallbackRunner>;

    virtual ~CompletionCallbackRunner() {}

    CompletionCallback* callback_;

    DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner);
  };

  // Addresses this connection may use.  Held by reference.
  const AddressList& address_list_;

  // Callback handed to OnStartOpenConnection() while waiting; NULL otherwise.
  CompletionCallback* callback_;
  // True if waiting another websocket connection is established.
  // False if the websocket is performing handshaking.
  bool waiting_;

  // True if the websocket handshake is completed.
  // If true, it will be removed from queue and deleted from the SocketStream
  // UserData soon.
  bool handshake_finished_;

  // Buffer for read data to check handshake response message.
  scoped_refptr<GrowableIOBuffer> buffer_;

  DISALLOW_COPY_AND_ASSIGN(WebSocketState);
};

// Key under which a socket's WebSocketState is stored with
// SocketStream::SetUserData() and looked up with GetUserData().
const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState";

// Registers this throttle for the "ws" and "wss" schemes, in that order.
WebSocketThrottle::WebSocketThrottle() {
  static const char* const kSchemes[] = { "ws", "wss" };
  for (size_t i = 0; i < sizeof(kSchemes) / sizeof(kSchemes[0]); ++i)
    SocketStreamThrottle::RegisterSocketStreamThrottle(kSchemes[i], this);
}

// Debug-checks that no connection is still tracked in |queue_| or
// |addr_map_| when the throttle is destroyed.
WebSocketThrottle::~WebSocketThrottle() {
  DCHECK(queue_.empty());
  DCHECK(addr_map_.empty());
}

// Creates the throttle state for |socket|, queues it, and reports whether
// the connection may proceed right away.  The return value is the one
// produced by WebSocketState::OnStartOpenConnection(): OK when the state is
// not waiting, ERR_IO_PENDING when |callback| has been stored for later.
int WebSocketThrottle::OnStartOpenConnection(
    SocketStream* socket, CompletionCallback* callback) {
  WebSocketState* const new_state =
      new WebSocketState(socket->address_list());
  // Queue first: PutInQueue() is what decides whether |new_state| must wait.
  PutInQueue(socket, new_state);
  const int rv = new_state->OnStartOpenConnection(callback);
  return rv;
}

// Forwards received data to the socket's WebSocketState.  Once that state
// reports the handshake as finished, the socket is taken out of the queues
// and any connection that became eligible is woken up.
// Returns OK when the socket has no state attached, otherwise the result of
// WebSocketState::OnRead().
int WebSocketThrottle::OnRead(SocketStream* socket,
                              const char* data, int len,
                              CompletionCallback* callback) {
  WebSocketState* const ws_state = static_cast<WebSocketState*>(
      socket->GetUserData(WebSocketState::kKeyName));
  // No state attached means the handshake already completed; nothing to do.
  if (ws_state == NULL)
    return OK;

  const int rv = ws_state->OnRead(data, len, callback);
  if (!ws_state->HandshakeFinished())
    return rv;

  RemoveFromQueue(socket, ws_state);
  WakeupSocketIfNecessary();
  return rv;
}

// Writes are not throttled: all arguments are ignored and OK is returned.
int WebSocketThrottle::OnWrite(SocketStream* socket,
                               const char* data, int len,
                               CompletionCallback* callback) {
  // Do nothing.
  return OK;
}

// Called when |socket| closes.  If the socket still has throttle state
// attached, that state is removed from the queues and any connection that
// became eligible is woken up.  Sockets without state are ignored.
void WebSocketThrottle::OnClose(SocketStream* socket) {
  WebSocketState* const ws_state = static_cast<WebSocketState*>(
      socket->GetUserData(WebSocketState::kKeyName));
  if (ws_state != NULL) {
    RemoveFromQueue(socket, ws_state);
    WakeupSocketIfNecessary();
  }
}

// Attaches |state| to |socket| as UserData and enqueues it: once in the
// global |queue_|, and once per entry of the socket's address list in the
// per-address queue of |addr_map_| (created on demand).
// |state| is marked as waiting when a different connection is already queued
// ahead of it for any of those addresses.
void WebSocketThrottle::PutInQueue(SocketStream* socket,
                                   WebSocketState* state) {
  socket->SetUserData(WebSocketState::kKeyName, state);
  queue_.push_back(state);
  const AddressList& address_list = socket->address_list();
  for (const struct addrinfo* addrinfo = address_list.head();
       addrinfo != NULL;
       addrinfo = addrinfo->ai_next) {
    std::string addrkey = AddrinfoToHashkey(addrinfo);
    ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
    if (iter == addr_map_.end()) {
      // First connection for this address: it heads a new queue.
      ConnectingQueue* queue = new ConnectingQueue();
      queue->push_back(state);
      addr_map_[addrkey] = queue;
      continue;
    }

    ConnectingQueue* queue = iter->second;
    // Queues are removed from the map as soon as they become empty.
    DCHECK(!queue->empty());
    // The address list may contain the same address more than once.  In
    // that case the queue found here was created by |state| itself a moment
    // ago, and marking |state| as waiting would make it wait for itself.
    // Only wait when a different connection is at the front.
    if (queue->front() != state)
      state->SetWaiting();
    // Still push one entry per address-list entry so that removal, which
    // walks the same address list, stays balanced.
    queue->push_back(state);
  }
}

// Removes |state| from every per-address queue named by the address list of
// |socket| and from the global |queue_|, drops per-address queues that
// become empty, and finally clears the socket's UserData slot for
// WebSocketState::kKeyName.
// |state| must not be used after this call: per the WebSocketState class
// comment the state is owned by the socket through that UserData slot, so
// clearing it presumably destroys |state|.
void WebSocketThrottle::RemoveFromQueue(SocketStream* socket,
                                        WebSocketState* state) {
  const AddressList& address_list = socket->address_list();
  for (const struct addrinfo* addrinfo = address_list.head();
       addrinfo != NULL;
       addrinfo = addrinfo->ai_next) {
    std::string addrkey = AddrinfoToHashkey(addrinfo);
    ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
    // When the address list repeats an address, the first occurrence has
    // already removed every entry of |state| (and possibly the queue).
    if (iter == addr_map_.end())
      continue;

    ConnectingQueue* queue = iter->second;
    // |state| is not necessarily at the front: a socket may be closed while
    // it is still waiting behind another connection.  Erase the entries
    // that belong to |state| instead of blindly popping the front, which
    // would drop some other connection's entry.
    for (ConnectingQueue::iterator entry = queue->begin();
         entry != queue->end();) {
      if (*entry == state)
        entry = queue->erase(entry);
      else
        ++entry;
    }
    if (queue->empty()) {
      delete queue;
      addr_map_.erase(iter);
    }
  }
  for (ConnectingQueue::iterator iter = queue_.begin();
       iter != queue_.end();
       ++iter) {
    if (*iter == state) {
      queue_.erase(iter);
      break;
    }
  }
  socket->SetUserData(WebSocketState::kKeyName, NULL);
}

void WebSocketThrottle::WakeupSocketIfNecessary() {
  for (ConnectingQueue::iterator iter = queue_.begin();
       iter != queue_.end();
       ++iter) {
    WebSocketState* state = *iter;
    if (!state->IsWaiting())
      continue;

    bool should_wakeup = true;
    const AddressList& address_list = state->address_list();
    for (const struct addrinfo* addrinfo = address_list.head();
         addrinfo != NULL;
         addrinfo = addrinfo->ai_next) {
      std::string addrkey = AddrinfoToHashkey(addrinfo);
      ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
      DCHECK(iter != addr_map_.end());
      ConnectingQueue* queue = iter->second;
      if (state != queue->front()) {
        should_wakeup = false;
        break;
      }
    }
    if (should_wakeup)
      state->Wakeup();
  }
}

/* static */
// Accesses the WebSocketThrottle singleton and discards the result.
// NOTE(review): presumably this is what creates the instance, whose
// constructor registers it for the "ws" and "wss" schemes -- confirm against
// Singleton<>.
void WebSocketThrottle::Init() {
  Singleton<WebSocketThrottle>::get();
}

}  // namespace net