// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/websockets/websocket_stream.h"

#include "base/logging.h"
#include "base/memory/scoped_ptr.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_status_code.h"
#include "net/url_request/url_request.h"
#include "net/url_request/url_request_context.h"
#include "net/websockets/websocket_errors.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "net/websockets/websocket_handshake_stream_base.h"
#include "net/websockets/websocket_handshake_stream_create_helper.h"
#include "net/websockets/websocket_test_util.h"
#include "url/gurl.h"

namespace net {
namespace {

class StreamRequestImpl;

// URLRequest::Delegate that drives the WebSocket opening handshake on behalf
// of a StreamRequestImpl. The StreamRequestImpl owns this object; |owner_| is
// a non-owning back-pointer to it.
class Delegate : public URLRequest::Delegate {
 public:
  explicit Delegate(StreamRequestImpl* owner) : owner_(owner) {}
  virtual ~Delegate() {}

  // URLRequest::Delegate implementation.
  virtual void OnResponseStarted(URLRequest* request) OVERRIDE;
  virtual void OnReadCompleted(URLRequest* request, int bytes_read) OVERRIDE;
  virtual void OnAuthRequired(URLRequest* request,
                              AuthChallengeInfo* auth_info) OVERRIDE;
  virtual void OnCertificateRequested(
      URLRequest* request,
      SSLCertRequestInfo* cert_request_info) OVERRIDE;
  virtual void OnSSLCertificateError(URLRequest* request,
                                     const SSLInfo& ssl_info,
                                     bool fatal) OVERRIDE;

 private:
  // The request object that created this delegate. Not owned.
  StreamRequestImpl* owner_;
};

// Concrete WebSocketStreamRequest handed back to callers of
// CreateAndConnectStream*(). It owns the URLRequest that performs the opening
// handshake and forwards the outcome to the ConnectDelegate.
class StreamRequestImpl : public WebSocketStreamRequest {
 public:
  // |create_helper| is not owned by this object; the caller is expected to
  // transfer its ownership to |url_request_| (see SetUserData() in
  // CreateAndConnectStreamWithCreateHelper).
  StreamRequestImpl(
      const GURL& url,
      const URLRequestContext* context,
      scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
      WebSocketHandshakeStreamCreateHelper* create_helper)
      : delegate_(new Delegate(this)),
        url_request_(url, DEFAULT_PRIORITY, delegate_.get(), context),
        connect_delegate_(connect_delegate.Pass()),
        create_helper_(create_helper) {}

  // Destroying this object destroys the URLRequest, which cancels the request
  // and so terminates the handshake if it is incomplete.
  virtual ~StreamRequestImpl() {}

  // Exposes the underlying request so it can be configured and started.
  URLRequest* url_request() { return &url_request_; }

  // Completes the handshake: upgrades the handshake stream obtained from
  // |create_helper_| and passes the result to the ConnectDelegate.
  void PerformUpgrade() {
    connect_delegate_->OnSuccess(create_helper_->stream()->Upgrade());
  }

  // Tells the ConnectDelegate the handshake failed, reporting
  // kWebSocketErrorAbnormalClosure.
  void ReportFailure() {
    connect_delegate_->OnFailure(kWebSocketErrorAbnormalClosure);
  }

 private:
  // |delegate_| needs to be declared before |url_request_| so that it gets
  // initialised first. Because members are destroyed in reverse declaration
  // order, this also keeps the delegate alive while |url_request_| is torn
  // down.
  scoped_ptr<Delegate> delegate_;

  // Deleting the StreamRequestImpl object deletes this URLRequest object,
  // cancelling the whole connection.
  URLRequest url_request_;

  // Receives exactly the success/failure notifications issued above.
  scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate_;

  // Owned by the URLRequest.
  WebSocketHandshakeStreamCreateHelper* create_helper_;
};

// Invoked when the handshake response headers are available.
//
// A 101 Switching Protocols response completes the handshake and hands the
// upgraded stream to the ConnectDelegate. Any other response code is a
// handshake failure and is reported as such, so the ConnectDelegate is not
// left waiting forever.
void Delegate::OnResponseStarted(URLRequest* request) {
  switch (request->GetResponseCode()) {
    case HTTP_SWITCHING_PROTOCOLS:
      owner_->PerformUpgrade();
      return;

    case HTTP_UNAUTHORIZED:
    case HTTP_PROXY_AUTHENTICATION_REQUIRED:
      // Authentication is not supported: OnAuthRequired() always cancels the
      // challenge and the request is started with
      // LOAD_DO_NOT_PROMPT_FOR_LOGIN. A 401/407 that reaches this point is
      // therefore final. Returning without reporting it would leave the
      // handshake pending with the ConnectDelegate never notified.
      owner_->ReportFailure();
      return;

    default:
      owner_->ReportFailure();
      return;
  }
}

// HTTP authentication is not supported for the WebSocket handshake, so any
// challenge is cancelled immediately. The request is also started with
// LOAD_DO_NOT_PROMPT_FOR_LOGIN. NOTE(review): per the URLRequest::CancelAuth()
// contract, the 401/407 response is then delivered to OnResponseStarted();
// confirm against url_request.h.
void Delegate::OnAuthRequired(URLRequest* request,
                              AuthChallengeInfo* auth_info) {
  request->CancelAuth();
}

// No client certificate is offered during the handshake: when the server asks
// for one, the request continues with a NULL certificate.
void Delegate::OnCertificateRequested(URLRequest* request,
                                      SSLCertRequestInfo* cert_request_info) {
  request->ContinueWithCertificate(NULL);
}

// Any SSL certificate error aborts the handshake, whether or not it is
// |fatal|; certificate errors are never overridden here. NOTE(review): the
// cancellation is presumably surfaced through OnResponseStarted() and hence
// ReportFailure(); confirm against the URLRequest::Cancel() contract.
void Delegate::OnSSLCertificateError(URLRequest* request,
                                     const SSLInfo& ssl_info,
                                     bool fatal) {
  request->Cancel();
}

// Nothing in this file calls URLRequest::Read() on the handshake request, so
// no read can ever complete.
void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) {
  NOTREACHED();
}

// Internal implementation of CreateAndConnectStream and
// CreateAndConnectStreamForTesting.
//
// This function:
//   - builds a StreamRequestImpl for |socket_url|;
//   - attaches the WebSocket opening-handshake headers;
//   - transfers ownership of |create_helper| to the underlying URLRequest
//     (as user data);
//   - starts the request.
// Destroying the returned object cancels the handshake if it is still
// incomplete. |net_log| is currently unused.
scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamWithCreateHelper(
    const GURL& socket_url,
    scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
    const GURL& origin,
    URLRequestContext* url_request_context,
    const BoundNetLog& net_log,
    scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) {
  scoped_ptr<StreamRequestImpl> request(
      new StreamRequestImpl(socket_url,
                            url_request_context,
                            connect_delegate.Pass(),
                            create_helper.get()));
  URLRequest* url_request = request->url_request();

  HttpRequestHeaders headers;
  headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
  headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade);
  headers.SetHeader(HttpRequestHeaders::kOrigin, origin.spec());
  headers.SetHeader(websockets::kSecWebSocketVersion,
                    websockets::kSupportedVersion);
  url_request->SetExtraRequestHeaders(headers);

  // The URLRequest takes ownership of the helper. StreamRequestImpl keeps only
  // the raw pointer obtained via get() above.
  url_request->SetUserData(
      WebSocketHandshakeStreamBase::CreateHelper::DataKey(),
      create_helper.release());

  // Handshake responses must not come from the cache. Authentication is not
  // supported (see Delegate::OnAuthRequired), so never prompt for login.
  url_request->SetLoadFlags(LOAD_DISABLE_CACHE | LOAD_DO_NOT_PROMPT_FOR_LOGIN);
  url_request->Start();
  return request.PassAs<WebSocketStreamRequest>();
}

}  // namespace

// Out-of-line definitions of the empty constructors and destructors declared
// in websocket_stream.h.
WebSocketStreamRequest::~WebSocketStreamRequest() {}

WebSocketStream::WebSocketStream() {}
WebSocketStream::~WebSocketStream() {}

WebSocketStream::ConnectDelegate::~ConnectDelegate() {}

// Production entry point: wraps |requested_subprotocols| in the default
// handshake-stream create helper and starts the opening handshake.
scoped_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream(
    const GURL& socket_url,
    const std::vector<std::string>& requested_subprotocols,
    const GURL& origin,
    URLRequestContext* url_request_context,
    const BoundNetLog& net_log,
    scoped_ptr<ConnectDelegate> connect_delegate) {
  scoped_ptr<WebSocketHandshakeStreamCreateHelper> default_helper(
      new WebSocketHandshakeStreamCreateHelper(requested_subprotocols));
  return CreateAndConnectStreamWithCreateHelper(
      socket_url, default_helper.Pass(), origin, url_request_context, net_log,
      connect_delegate.Pass());
}

// Declared in websocket_test_util.h. Lets tests inject their own
// |create_helper|; otherwise identical to
// WebSocketStream::CreateAndConnectStream().
scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamForTesting(
    const GURL& socket_url,
    scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
    const GURL& origin,
    URLRequestContext* url_request_context,
    const BoundNetLog& net_log,
    scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) {
  return CreateAndConnectStreamWithCreateHelper(
      socket_url, create_helper.Pass(), origin, url_request_context, net_log,
      connect_delegate.Pass());
}

}  // namespace net