// Copyright (c) 2010 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include <string> #include <vector> #include "base/callback.h" #include "base/utf_string_conversions.h" #include "net/base/auth.h" #include "net/base/mock_host_resolver.h" #include "net/base/net_log.h" #include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" #include "net/socket/socket_test_util.h" #include "net/socket_stream/socket_stream.h" #include "net/url_request/url_request_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/platform_test.h" struct SocketStreamEvent { enum EventType { EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE, EVENT_AUTH_REQUIRED, }; SocketStreamEvent(EventType type, net::SocketStream* socket_stream, int num, const std::string& str, net::AuthChallengeInfo* auth_challenge_info) : event_type(type), socket(socket_stream), number(num), data(str), auth_info(auth_challenge_info) {} EventType event_type; net::SocketStream* socket; int number; std::string data; scoped_refptr<net::AuthChallengeInfo> auth_info; }; class SocketStreamEventRecorder : public net::SocketStream::Delegate { public: explicit SocketStreamEventRecorder(net::CompletionCallback* callback) : on_connected_(NULL), on_sent_data_(NULL), on_received_data_(NULL), on_close_(NULL), on_auth_required_(NULL), callback_(callback) {} virtual ~SocketStreamEventRecorder() { delete on_connected_; delete on_sent_data_; delete on_received_data_; delete on_close_; delete on_auth_required_; } void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) { on_connected_ = callback; } void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) { on_sent_data_ = callback; } void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) { on_received_data_ = callback; } void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) { on_close_ = callback; } void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) { on_auth_required_ = callback; } virtual void OnConnected(net::SocketStream* socket, int num_pending_send_allowed) { events_.push_back( SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED, socket, num_pending_send_allowed, std::string(), NULL)); if (on_connected_) on_connected_->Run(&events_.back()); } virtual void OnSentData(net::SocketStream* socket, int amount_sent) { events_.push_back( SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA, socket, amount_sent, std::string(), NULL)); if (on_sent_data_) on_sent_data_->Run(&events_.back()); } virtual void OnReceivedData(net::SocketStream* socket, const char* data, int len) { events_.push_back( SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA, socket, len, std::string(data, len), NULL)); if (on_received_data_) on_received_data_->Run(&events_.back()); } virtual void OnClose(net::SocketStream* socket) { events_.push_back( SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE, socket, 0, std::string(), NULL)); if (on_close_) on_close_->Run(&events_.back()); if (callback_) callback_->Run(net::OK); } virtual void OnAuthRequired(net::SocketStream* socket, net::AuthChallengeInfo* auth_info) { events_.push_back( SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED, socket, 0, std::string(), auth_info)); if (on_auth_required_) on_auth_required_->Run(&events_.back()); } void DoClose(SocketStreamEvent* event) { event->socket->Close(); } void DoRestartWithAuth(SocketStreamEvent* event) { VLOG(1) << "RestartWithAuth username=" << username_ << " password=" << password_; event->socket->RestartWithAuth(username_, password_); } void SetAuthInfo(const string16& username, const string16& password) { username_ = username; password_ = password; } const std::vector<SocketStreamEvent>& GetSeenEvents() const { return events_; } private: std::vector<SocketStreamEvent> events_; Callback1<SocketStreamEvent*>::Type* on_connected_; Callback1<SocketStreamEvent*>::Type* on_sent_data_; Callback1<SocketStreamEvent*>::Type* on_received_data_; Callback1<SocketStreamEvent*>::Type* on_close_; Callback1<SocketStreamEvent*>::Type* on_auth_required_; net::CompletionCallback* callback_; string16 username_; string16 password_; DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder); }; namespace net { class SocketStreamTest : public PlatformTest { public: virtual ~SocketStreamTest() {} virtual void SetUp() { mock_socket_factory_.reset(); handshake_request_ = kWebSocketHandshakeRequest; handshake_response_ = kWebSocketHandshakeResponse; } virtual void TearDown() { mock_socket_factory_.reset(); } virtual void SetWebSocketHandshakeMessage( const char* request, const char* response) { handshake_request_ = request; handshake_response_ = response; } virtual void AddWebSocketMessage(const std::string& message) { messages_.push_back(message); } virtual MockClientSocketFactory* GetMockClientSocketFactory() { mock_socket_factory_.reset(new MockClientSocketFactory); return mock_socket_factory_.get(); } virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) { event->socket->SendData( handshake_request_.data(), handshake_request_.size()); } virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) { // handshake response received. for (size_t i = 0; i < messages_.size(); i++) { std::vector<char> frame; frame.push_back('\0'); frame.insert(frame.end(), messages_[i].begin(), messages_[i].end()); frame.push_back('\xff'); EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size())); } // Actual ClientSocket close must happen after all frames queued by // SendData above are sent out. event->socket->Close(); } static const char* kWebSocketHandshakeRequest; static const char* kWebSocketHandshakeResponse; private: std::string handshake_request_; std::string handshake_response_; std::vector<std::string> messages_; scoped_ptr<MockClientSocketFactory> mock_socket_factory_; }; const char* SocketStreamTest::kWebSocketHandshakeRequest = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" "Sec-WebSocket-Protocol: sample\r\n" "Upgrade: WebSocket\r\n" "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" "Origin: http://example.com\r\n" "\r\n" "^n:ds[4U"; const char* SocketStreamTest::kWebSocketHandshakeResponse = "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Origin: http://example.com\r\n" "Sec-WebSocket-Location: ws://example.com/demo\r\n" "Sec-WebSocket-Protocol: sample\r\n" "\r\n" "8jKS'y:G*Co,Wxa-"; TEST_F(SocketStreamTest, CloseFlushPendingWrite) { TestCompletionCallback callback; scoped_ptr<SocketStreamEventRecorder> delegate( new SocketStreamEventRecorder(&callback)); // Necessary for NewCallback. SocketStreamTest* test = this; delegate->SetOnConnected(NewCallback( test, &SocketStreamTest::DoSendWebSocketHandshake)); delegate->SetOnReceivedData(NewCallback( test, &SocketStreamTest::DoCloseFlushPendingWriteTest)); MockHostResolver host_resolver; scoped_refptr<SocketStream> socket_stream( new SocketStream(GURL("ws://example.com/demo"), delegate.get())); socket_stream->set_context(new TestURLRequestContext()); socket_stream->SetHostResolver(&host_resolver); MockWrite data_writes[] = { MockWrite(SocketStreamTest::kWebSocketHandshakeRequest), MockWrite(true, "\0message1\xff", 10), MockWrite(true, "\0message2\xff", 10) }; MockRead data_reads[] = { MockRead(SocketStreamTest::kWebSocketHandshakeResponse), // Server doesn't close the connection after handshake. MockRead(true, ERR_IO_PENDING) }; AddWebSocketMessage("message1"); AddWebSocketMessage("message2"); scoped_refptr<DelayedSocketData> data_provider( new DelayedSocketData(1, data_reads, arraysize(data_reads), data_writes, arraysize(data_writes))); MockClientSocketFactory* mock_socket_factory = GetMockClientSocketFactory(); mock_socket_factory->AddSocketDataProvider(data_provider.get()); socket_stream->SetClientSocketFactory(mock_socket_factory); socket_stream->Connect(); callback.WaitForResult(); const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); EXPECT_EQ(6U, events.size()); EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type); EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type); EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type); EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type); EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type); EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type); } TEST_F(SocketStreamTest, BasicAuthProxy) { MockClientSocketFactory mock_socket_factory; MockWrite data_writes1[] = { MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" "Host: example.com\r\n" "Proxy-Connection: keep-alive\r\n\r\n"), }; MockRead data_reads1[] = { MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), MockRead("\r\n"), }; StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1), data_writes1, arraysize(data_writes1)); mock_socket_factory.AddSocketDataProvider(&data1); MockWrite data_writes2[] = { MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" "Host: example.com\r\n" "Proxy-Connection: keep-alive\r\n" "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), }; MockRead data_reads2[] = { MockRead("HTTP/1.1 200 Connection Established\r\n"), MockRead("Proxy-agent: Apache/2.2.8\r\n"), MockRead("\r\n"), // SocketStream::DoClose is run asynchronously. Socket can be read after // "\r\n". We have to give ERR_IO_PENDING to SocketStream then to indicate // server doesn't close the connection. MockRead(true, ERR_IO_PENDING) }; StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2), data_writes2, arraysize(data_writes2)); mock_socket_factory.AddSocketDataProvider(&data2); TestCompletionCallback callback; scoped_ptr<SocketStreamEventRecorder> delegate( new SocketStreamEventRecorder(&callback)); delegate->SetOnConnected(NewCallback(delegate.get(), &SocketStreamEventRecorder::DoClose)); delegate->SetAuthInfo(ASCIIToUTF16("foo"), ASCIIToUTF16("bar")); delegate->SetOnAuthRequired( NewCallback(delegate.get(), &SocketStreamEventRecorder::DoRestartWithAuth)); scoped_refptr<SocketStream> socket_stream( new SocketStream(GURL("ws://example.com/demo"), delegate.get())); socket_stream->set_context(new TestURLRequestContext("myproxy:70")); MockHostResolver host_resolver; socket_stream->SetHostResolver(&host_resolver); socket_stream->SetClientSocketFactory(&mock_socket_factory); socket_stream->Connect(); callback.WaitForResult(); const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); EXPECT_EQ(3U, events.size()); EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type); EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type); // TODO(eroman): Add back NetLogTest here... } } // namespace net