// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include <string> #include <vector> #include "base/bind.h" #include "base/compiler_specific.h" #include "base/location.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/message_loop/message_loop.h" #include "base/message_loop/message_loop_proxy.h" #include "base/run_loop.h" #include "base/single_thread_task_runner.h" #include "base/threading/thread.h" #include "base/time/time.h" #include "chrome/test/chromedriver/net/test_http_server.h" #include "chrome/test/chromedriver/net/websocket.h" #include "net/url_request/url_request_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" namespace { void OnConnectFinished(base::RunLoop* run_loop, int* save_error, int error) { *save_error = error; run_loop->Quit(); } void RunPending(base::MessageLoop* loop) { base::RunLoop run_loop; loop->PostTask(FROM_HERE, run_loop.QuitClosure()); run_loop.Run(); } class Listener : public WebSocketListener { public: explicit Listener(const std::vector<std::string>& messages) : messages_(messages) {} virtual ~Listener() { EXPECT_TRUE(messages_.empty()); } virtual void OnMessageReceived(const std::string& message) OVERRIDE { ASSERT_TRUE(messages_.size()); EXPECT_EQ(messages_[0], message); messages_.erase(messages_.begin()); if (messages_.empty()) base::MessageLoop::current()->Quit(); } virtual void OnClose() OVERRIDE { EXPECT_TRUE(false); } private: std::vector<std::string> messages_; }; class CloseListener : public WebSocketListener { public: explicit CloseListener(base::RunLoop* run_loop) : run_loop_(run_loop) {} virtual ~CloseListener() { EXPECT_FALSE(run_loop_); } virtual void OnMessageReceived(const std::string& message) OVERRIDE {} virtual void OnClose() OVERRIDE { EXPECT_TRUE(run_loop_); if (run_loop_) run_loop_->Quit(); run_loop_ = NULL; } private: base::RunLoop* run_loop_; }; class WebSocketTest : public testing::Test { public: WebSocketTest() {} virtual ~WebSocketTest() {} virtual void SetUp() OVERRIDE { ASSERT_TRUE(server_.Start()); } virtual void TearDown() OVERRIDE { server_.Stop(); } protected: scoped_ptr<WebSocket> CreateWebSocket(const GURL& url, WebSocketListener* listener) { int error; scoped_ptr<WebSocket> sock(new WebSocket(url, listener)); base::RunLoop run_loop; sock->Connect(base::Bind(&OnConnectFinished, &run_loop, &error)); loop_.PostDelayedTask( FROM_HERE, run_loop.QuitClosure(), base::TimeDelta::FromSeconds(10)); run_loop.Run(); if (error == net::OK) return sock.Pass(); return scoped_ptr<WebSocket>(); } scoped_ptr<WebSocket> CreateConnectedWebSocket(WebSocketListener* listener) { return CreateWebSocket(server_.web_socket_url(), listener); } void SendReceive(const std::vector<std::string>& messages) { Listener listener(messages); scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); ASSERT_TRUE(sock); for (size_t i = 0; i < messages.size(); ++i) { ASSERT_TRUE(sock->Send(messages[i])); } base::RunLoop run_loop; loop_.PostDelayedTask( FROM_HERE, run_loop.QuitClosure(), base::TimeDelta::FromSeconds(10)); run_loop.Run(); } base::MessageLoopForIO loop_; TestHttpServer server_; }; } // namespace TEST_F(WebSocketTest, CreateDestroy) { CloseListener listener(NULL); WebSocket sock(GURL("ws://127.0.0.1:2222"), &listener); } TEST_F(WebSocketTest, Connect) { CloseListener listener(NULL); ASSERT_TRUE(CreateWebSocket(server_.web_socket_url(), &listener)); RunPending(&loop_); ASSERT_TRUE(server_.WaitForConnectionsToClose()); } TEST_F(WebSocketTest, ConnectNoServer) { CloseListener listener(NULL); ASSERT_FALSE(CreateWebSocket(GURL("ws://127.0.0.1:33333"), NULL)); } TEST_F(WebSocketTest, Connect404) { server_.SetRequestAction(TestHttpServer::kNotFound); CloseListener listener(NULL); ASSERT_FALSE(CreateWebSocket(server_.web_socket_url(), NULL)); RunPending(&loop_); ASSERT_TRUE(server_.WaitForConnectionsToClose()); } TEST_F(WebSocketTest, ConnectServerClosesConn) { server_.SetRequestAction(TestHttpServer::kClose); CloseListener listener(NULL); ASSERT_FALSE(CreateWebSocket(server_.web_socket_url(), &listener)); } TEST_F(WebSocketTest, CloseOnReceive) { server_.SetMessageAction(TestHttpServer::kCloseOnMessage); base::RunLoop run_loop; CloseListener listener(&run_loop); scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); ASSERT_TRUE(sock); ASSERT_TRUE(sock->Send("hi")); loop_.PostDelayedTask( FROM_HERE, run_loop.QuitClosure(), base::TimeDelta::FromSeconds(10)); run_loop.Run(); } TEST_F(WebSocketTest, CloseOnSend) { base::RunLoop run_loop; CloseListener listener(&run_loop); scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener)); ASSERT_TRUE(sock); server_.Stop(); sock->Send("hi"); loop_.PostDelayedTask( FROM_HERE, run_loop.QuitClosure(), base::TimeDelta::FromSeconds(10)); run_loop.Run(); ASSERT_FALSE(sock->Send("hi")); } TEST_F(WebSocketTest, SendReceive) { std::vector<std::string> messages; messages.push_back("hello"); SendReceive(messages); } TEST_F(WebSocketTest, SendReceiveLarge) { std::vector<std::string> messages; messages.push_back(std::string(10 << 20, 'a')); SendReceive(messages); } TEST_F(WebSocketTest, SendReceiveMultiple) { std::vector<std::string> messages; messages.push_back("1"); messages.push_back("2"); messages.push_back("3"); SendReceive(messages); }