// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "remoting/protocol/channel_multiplexer.h" #include "base/bind.h" #include "base/message_loop/message_loop.h" #include "base/run_loop.h" #include "net/base/net_errors.h" #include "net/socket/socket.h" #include "net/socket/stream_socket.h" #include "remoting/base/constants.h" #include "remoting/protocol/connection_tester.h" #include "remoting/protocol/fake_session.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" using testing::_; using testing::AtMost; using testing::InvokeWithoutArgs; namespace remoting { namespace protocol { namespace { const int kMessageSize = 1024; const int kMessages = 100; const char kMuxChannelName[] = "mux"; const char kTestChannelName[] = "test"; const char kTestChannelName2[] = "test2"; void QuitCurrentThread() { base::MessageLoop::current()->PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); } class MockSocketCallback { public: MOCK_METHOD1(OnDone, void(int result)); }; class MockConnectCallback { public: MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket)); void OnConnected(scoped_ptr<net::StreamSocket> socket) { OnConnectedPtr(socket.release()); } }; } // namespace class ChannelMultiplexerTest : public testing::Test { public: void DeleteAll() { host_socket1_.reset(); host_socket2_.reset(); client_socket1_.reset(); client_socket2_.reset(); host_mux_.reset(); client_mux_.reset(); } void DeleteAfterSessionFail() { host_mux_->CancelChannelCreation(kTestChannelName2); DeleteAll(); } protected: virtual void SetUp() OVERRIDE { // Create pair of multiplexers and connect them to each other. host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName)); client_mux_.reset(new ChannelMultiplexer(&client_session_, kMuxChannelName)); } // Connect sockets to each other. Must be called after we've created at least // one channel with each multiplexer. void ConnectSockets() { FakeSocket* host_socket = host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName); FakeSocket* client_socket = client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName); host_socket->PairWith(client_socket); // Make writes asynchronous in one direction. host_socket->set_async_write(true); } void CreateChannel(const std::string& name, scoped_ptr<net::StreamSocket>* host_socket, scoped_ptr<net::StreamSocket>* client_socket) { int counter = 2; host_mux_->CreateStreamChannel(name, base::Bind( &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), host_socket, &counter)); client_mux_->CreateStreamChannel(name, base::Bind( &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), client_socket, &counter)); message_loop_.Run(); EXPECT_TRUE(host_socket->get()); EXPECT_TRUE(client_socket->get()); } void OnChannelConnected( scoped_ptr<net::StreamSocket>* storage, int* counter, scoped_ptr<net::StreamSocket> socket) { *storage = socket.Pass(); --(*counter); EXPECT_GE(*counter, 0); if (*counter == 0) QuitCurrentThread(); } scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) { scoped_refptr<net::IOBufferWithSize> result = new net::IOBufferWithSize(size); for (int i = 0; i< size; ++i) { result->data()[i] = rand() % 256; } return result; } base::MessageLoop message_loop_; FakeSession host_session_; FakeSession client_session_; scoped_ptr<ChannelMultiplexer> host_mux_; scoped_ptr<ChannelMultiplexer> client_mux_; scoped_ptr<net::StreamSocket> host_socket1_; scoped_ptr<net::StreamSocket> client_socket1_; scoped_ptr<net::StreamSocket> host_socket2_; scoped_ptr<net::StreamSocket> client_socket2_; }; TEST_F(ChannelMultiplexerTest, OneChannel) { scoped_ptr<net::StreamSocket> host_socket; scoped_ptr<net::StreamSocket> client_socket; ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName, &host_socket, &client_socket)); ConnectSockets(); StreamConnectionTester tester(host_socket.get(), client_socket.get(), kMessageSize, kMessages); tester.Start(); message_loop_.Run(); tester.CheckResults(); } TEST_F(ChannelMultiplexerTest, TwoChannels) { scoped_ptr<net::StreamSocket> host_socket1_; scoped_ptr<net::StreamSocket> client_socket1_; ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); scoped_ptr<net::StreamSocket> host_socket2_; scoped_ptr<net::StreamSocket> client_socket2_; ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); ConnectSockets(); StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), kMessageSize, kMessages); StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), kMessageSize, kMessages); tester1.Start(); tester2.Start(); while (!tester1.done() || !tester2.done()) { message_loop_.Run(); } tester1.CheckResults(); tester2.CheckResults(); } // Four channels, two in each direction TEST_F(ChannelMultiplexerTest, FourChannels) { scoped_ptr<net::StreamSocket> host_socket1_; scoped_ptr<net::StreamSocket> client_socket1_; ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); scoped_ptr<net::StreamSocket> host_socket2_; scoped_ptr<net::StreamSocket> client_socket2_; ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); scoped_ptr<net::StreamSocket> host_socket3; scoped_ptr<net::StreamSocket> client_socket3; ASSERT_NO_FATAL_FAILURE( CreateChannel("test3", &host_socket3, &client_socket3)); scoped_ptr<net::StreamSocket> host_socket4; scoped_ptr<net::StreamSocket> client_socket4; ASSERT_NO_FATAL_FAILURE( CreateChannel("ch4", &host_socket4, &client_socket4)); ConnectSockets(); StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), kMessageSize, kMessages); StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), kMessageSize, kMessages); StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(), kMessageSize, kMessages); StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(), kMessageSize, kMessages); tester1.Start(); tester2.Start(); tester3.Start(); tester4.Start(); while (!tester1.done() || !tester2.done() || !tester3.done() || !tester4.done()) { message_loop_.Run(); } tester1.CheckResults(); tester2.CheckResults(); tester3.CheckResults(); tester4.CheckResults(); } TEST_F(ChannelMultiplexerTest, WriteFailSync) { scoped_ptr<net::StreamSocket> host_socket1_; scoped_ptr<net::StreamSocket> client_socket1_; ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); scoped_ptr<net::StreamSocket> host_socket2_; scoped_ptr<net::StreamSocket> client_socket2_; ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); ConnectSockets(); host_session_.GetStreamChannel(kMuxChannelName)-> set_next_write_error(net::ERR_FAILED); host_session_.GetStreamChannel(kMuxChannelName)-> set_async_write(false); scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); MockSocketCallback cb1; MockSocketCallback cb2; EXPECT_CALL(cb1, OnDone(_)) .Times(0); EXPECT_CALL(cb2, OnDone(_)) .Times(0); EXPECT_EQ(net::ERR_FAILED, host_socket1_->Write(buf.get(), buf->size(), base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb1)))); EXPECT_EQ(net::ERR_FAILED, host_socket2_->Write(buf.get(), buf->size(), base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb2)))); base::RunLoop().RunUntilIdle(); } TEST_F(ChannelMultiplexerTest, WriteFailAsync) { ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); ConnectSockets(); host_session_.GetStreamChannel(kMuxChannelName)-> set_next_write_error(net::ERR_FAILED); host_session_.GetStreamChannel(kMuxChannelName)-> set_async_write(true); scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); MockSocketCallback cb1; MockSocketCallback cb2; EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)); EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)); EXPECT_EQ(net::ERR_IO_PENDING, host_socket1_->Write(buf.get(), buf->size(), base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb1)))); EXPECT_EQ(net::ERR_IO_PENDING, host_socket2_->Write(buf.get(), buf->size(), base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb2)))); base::RunLoop().RunUntilIdle(); } TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) { ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); ASSERT_NO_FATAL_FAILURE( CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); ConnectSockets(); host_session_.GetStreamChannel(kMuxChannelName)-> set_next_write_error(net::ERR_FAILED); host_session_.GetStreamChannel(kMuxChannelName)-> set_async_write(true); scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); MockSocketCallback cb1; MockSocketCallback cb2; EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)) .Times(AtMost(1)) .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)) .Times(AtMost(1)) .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); EXPECT_EQ(net::ERR_IO_PENDING, host_socket1_->Write(buf.get(), buf->size(), base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb1)))); EXPECT_EQ(net::ERR_IO_PENDING, host_socket2_->Write(buf.get(), buf->size(), base::Bind(&MockSocketCallback::OnDone, base::Unretained(&cb2)))); base::RunLoop().RunUntilIdle(); // Check that the sockets were destroyed. EXPECT_FALSE(host_mux_.get()); } TEST_F(ChannelMultiplexerTest, SessionFail) { host_session_.set_async_creation(true); host_session_.set_error(AUTHENTICATION_FAILED); MockConnectCallback cb1; MockConnectCallback cb2; host_mux_->CreateStreamChannel(kTestChannelName, base::Bind( &MockConnectCallback::OnConnected, base::Unretained(&cb1))); host_mux_->CreateStreamChannel(kTestChannelName2, base::Bind( &MockConnectCallback::OnConnected, base::Unretained(&cb2))); EXPECT_CALL(cb1, OnConnectedPtr(NULL)) .Times(AtMost(1)) .WillOnce(InvokeWithoutArgs( this, &ChannelMultiplexerTest::DeleteAfterSessionFail)); EXPECT_CALL(cb2, OnConnectedPtr(_)) .Times(0); base::RunLoop().RunUntilIdle(); } } // namespace protocol } // namespace remoting