// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "remoting/protocol/channel_multiplexer.h"
#include "base/bind.h"
#include "base/message_loop/message_loop.h"
#include "base/run_loop.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
#include "net/socket/stream_socket.h"
#include "remoting/base/constants.h"
#include "remoting/protocol/connection_tester.h"
#include "remoting/protocol/fake_session.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::AtMost;
using testing::InvokeWithoutArgs;
namespace remoting {
namespace protocol {
namespace {
// Message size and message count handed to every StreamConnectionTester
// created by the tests below.
const int kMessageSize = 1024;
const int kMessages = 100;
// Name passed to both ChannelMultiplexer instances as the channel they
// multiplex over; also used to look up that channel on the fake session.
const char kMuxChannelName[] = "mux";
// Names of the multiplexed channels that the tests create.
const char kTestChannelName[] = "test";
const char kTestChannelName2[] = "test2";
void QuitCurrentThread() {
base::MessageLoop::current()->PostTask(FROM_HERE,
base::MessageLoop::QuitClosure());
}
// gmock-based stand-in for a socket completion callback. Tests bind OnDone as
// the callback passed to Write() and set expectations on the |result| it is
// invoked with.
class MockSocketCallback {
public:
MOCK_METHOD1(OnDone, void(int result));
};
class MockConnectCallback {
public:
MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket));
void OnConnected(scoped_ptr<net::StreamSocket> socket) {
OnConnectedPtr(socket.release());
}
};
} // namespace
class ChannelMultiplexerTest : public testing::Test {
public:
void DeleteAll() {
host_socket1_.reset();
host_socket2_.reset();
client_socket1_.reset();
client_socket2_.reset();
host_mux_.reset();
client_mux_.reset();
}
void DeleteAfterSessionFail() {
host_mux_->CancelChannelCreation(kTestChannelName2);
DeleteAll();
}
protected:
virtual void SetUp() OVERRIDE {
// Create pair of multiplexers and connect them to each other.
host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName));
client_mux_.reset(new ChannelMultiplexer(&client_session_,
kMuxChannelName));
}
// Connect sockets to each other. Must be called after we've created at least
// one channel with each multiplexer.
void ConnectSockets() {
FakeSocket* host_socket =
host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
FakeSocket* client_socket =
client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
host_socket->PairWith(client_socket);
// Make writes asynchronous in one direction.
host_socket->set_async_write(true);
}
void CreateChannel(const std::string& name,
scoped_ptr<net::StreamSocket>* host_socket,
scoped_ptr<net::StreamSocket>* client_socket) {
int counter = 2;
host_mux_->CreateStreamChannel(name, base::Bind(
&ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
host_socket, &counter));
client_mux_->CreateStreamChannel(name, base::Bind(
&ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
client_socket, &counter));
message_loop_.Run();
EXPECT_TRUE(host_socket->get());
EXPECT_TRUE(client_socket->get());
}
void OnChannelConnected(
scoped_ptr<net::StreamSocket>* storage,
int* counter,
scoped_ptr<net::StreamSocket> socket) {
*storage = socket.Pass();
--(*counter);
EXPECT_GE(*counter, 0);
if (*counter == 0)
QuitCurrentThread();
}
scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) {
scoped_refptr<net::IOBufferWithSize> result =
new net::IOBufferWithSize(size);
for (int i = 0; i< size; ++i) {
result->data()[i] = rand() % 256;
}
return result;
}
base::MessageLoop message_loop_;
FakeSession host_session_;
FakeSession client_session_;
scoped_ptr<ChannelMultiplexer> host_mux_;
scoped_ptr<ChannelMultiplexer> client_mux_;
scoped_ptr<net::StreamSocket> host_socket1_;
scoped_ptr<net::StreamSocket> client_socket1_;
scoped_ptr<net::StreamSocket> host_socket2_;
scoped_ptr<net::StreamSocket> client_socket2_;
};
// Creates a single multiplexed channel, runs a StreamConnectionTester from
// the host end to the client end, and checks the tester's results.
TEST_F(ChannelMultiplexerTest, OneChannel) {
  scoped_ptr<net::StreamSocket> host_end;
  scoped_ptr<net::StreamSocket> client_end;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_end, &client_end));

  ConnectSockets();

  StreamConnectionTester connection_tester(
      host_end.get(), client_end.get(), kMessageSize, kMessages);
  connection_tester.Start();
  message_loop_.Run();
  connection_tester.CheckResults();
}
// Creates two multiplexed channels, runs one StreamConnectionTester over each
// at the same time (host end to client end), and checks both testers'
// results.
TEST_F(ChannelMultiplexerTest, TwoChannels) {
  // These are plain locals, so they are named without the trailing underscore
  // reserved for data members; this keeps them from shadowing the fixture's
  // host_socket1_/client_socket1_/host_socket2_/client_socket2_ members.
  scoped_ptr<net::StreamSocket> host_socket1;
  scoped_ptr<net::StreamSocket> client_socket1;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1, &client_socket1));

  scoped_ptr<net::StreamSocket> host_socket2;
  scoped_ptr<net::StreamSocket> client_socket2;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2, &client_socket2));

  ConnectSockets();

  StreamConnectionTester tester1(host_socket1.get(), client_socket1.get(),
                                 kMessageSize, kMessages);
  StreamConnectionTester tester2(host_socket2.get(), client_socket2.get(),
                                 kMessageSize, kMessages);
  tester1.Start();
  tester2.Start();

  // Keep re-entering the loop until both testers report done().
  while (!tester1.done() || !tester2.done()) {
    message_loop_.Run();
  }

  tester1.CheckResults();
  tester2.CheckResults();
}
// Four channels, two in each direction: testers 1 and 2 are constructed with
// the host end first, testers 3 and 4 with the client end first.
TEST_F(ChannelMultiplexerTest, FourChannels) {
  // All sockets here are plain locals, so none of them uses the trailing
  // underscore reserved for data members; this keeps the first two pairs from
  // shadowing the fixture's members of the same name.
  scoped_ptr<net::StreamSocket> host_socket1;
  scoped_ptr<net::StreamSocket> client_socket1;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1, &client_socket1));

  scoped_ptr<net::StreamSocket> host_socket2;
  scoped_ptr<net::StreamSocket> client_socket2;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2, &client_socket2));

  scoped_ptr<net::StreamSocket> host_socket3;
  scoped_ptr<net::StreamSocket> client_socket3;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel("test3", &host_socket3, &client_socket3));

  scoped_ptr<net::StreamSocket> host_socket4;
  scoped_ptr<net::StreamSocket> client_socket4;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel("ch4", &host_socket4, &client_socket4));

  ConnectSockets();

  StreamConnectionTester tester1(host_socket1.get(), client_socket1.get(),
                                 kMessageSize, kMessages);
  StreamConnectionTester tester2(host_socket2.get(), client_socket2.get(),
                                 kMessageSize, kMessages);
  StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(),
                                 kMessageSize, kMessages);
  StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(),
                                 kMessageSize, kMessages);
  tester1.Start();
  tester2.Start();
  tester3.Start();
  tester4.Start();

  // Keep re-entering the loop until every tester reports done().
  while (!tester1.done() || !tester2.done() ||
         !tester3.done() || !tester4.done()) {
    message_loop_.Run();
  }

  tester1.CheckResults();
  tester2.CheckResults();
  tester3.CheckResults();
  tester4.CheckResults();
}
// Arms the host session's mux channel to fail its next write with
// net::ERR_FAILED while writes are synchronous, then expects Write() on both
// multiplexed host sockets to return net::ERR_FAILED directly and neither
// completion callback to be invoked.
TEST_F(ChannelMultiplexerTest, WriteFailSync) {
  // These are plain locals, so they are named without the trailing underscore
  // reserved for data members; this keeps them from shadowing the fixture's
  // host_socket1_/client_socket1_/host_socket2_/client_socket2_ members.
  scoped_ptr<net::StreamSocket> host_socket1;
  scoped_ptr<net::StreamSocket> client_socket1;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1, &client_socket1));

  scoped_ptr<net::StreamSocket> host_socket2;
  scoped_ptr<net::StreamSocket> client_socket2;
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2, &client_socket2));

  ConnectSockets();

  // ConnectSockets() switched the host side to asynchronous writes; turn that
  // back off so the injected error is reported from Write() itself.
  host_session_.GetStreamChannel(kMuxChannelName)->
      set_next_write_error(net::ERR_FAILED);
  host_session_.GetStreamChannel(kMuxChannelName)->
      set_async_write(false);

  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);

  MockSocketCallback cb1;
  MockSocketCallback cb2;

  // The failure is returned synchronously, so no callback may fire.
  EXPECT_CALL(cb1, OnDone(_))
      .Times(0);
  EXPECT_CALL(cb2, OnDone(_))
      .Times(0);

  EXPECT_EQ(net::ERR_FAILED,
            host_socket1->Write(buf.get(),
                                buf->size(),
                                base::Bind(&MockSocketCallback::OnDone,
                                           base::Unretained(&cb1))));
  EXPECT_EQ(net::ERR_FAILED,
            host_socket2->Write(buf.get(),
                                buf->size(),
                                base::Bind(&MockSocketCallback::OnDone,
                                           base::Unretained(&cb2))));

  base::RunLoop().RunUntilIdle();
}
// Arms the host session's mux channel to fail its next write with
// net::ERR_FAILED while writes are asynchronous, then expects Write() on both
// multiplexed host sockets to return net::ERR_IO_PENDING and each completion
// callback to be invoked with net::ERR_FAILED.
TEST_F(ChannelMultiplexerTest, WriteFailAsync) {
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));

  ConnectSockets();

  host_session_.GetStreamChannel(kMuxChannelName)->
      set_next_write_error(net::ERR_FAILED);
  host_session_.GetStreamChannel(kMuxChannelName)->
      set_async_write(true);

  scoped_refptr<net::IOBufferWithSize> payload = CreateTestBuffer(100);

  MockSocketCallback write_done1;
  MockSocketCallback write_done2;

  // Each pending write must complete with the injected error.
  EXPECT_CALL(write_done1, OnDone(net::ERR_FAILED));
  EXPECT_CALL(write_done2, OnDone(net::ERR_FAILED));

  EXPECT_EQ(net::ERR_IO_PENDING,
            host_socket1_->Write(
                payload.get(), payload->size(),
                base::Bind(&MockSocketCallback::OnDone,
                           base::Unretained(&write_done1))));
  EXPECT_EQ(net::ERR_IO_PENDING,
            host_socket2_->Write(
                payload.get(), payload->size(),
                base::Bind(&MockSocketCallback::OnDone,
                           base::Unretained(&write_done2))));

  // Drain the loop so the deferred completions get delivered.
  base::RunLoop().RunUntilIdle();
}
// Same asynchronous write failure as WriteFailAsync, except that whichever
// completion callback runs calls DeleteAll() from inside the callback. Each
// callback is allowed to run at most once, and the host multiplexer must be
// gone once the loop has drained.
TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
  ASSERT_NO_FATAL_FAILURE(
      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));

  ConnectSockets();

  host_session_.GetStreamChannel(kMuxChannelName)->
      set_next_write_error(net::ERR_FAILED);
  host_session_.GetStreamChannel(kMuxChannelName)->
      set_async_write(true);

  scoped_refptr<net::IOBufferWithSize> payload = CreateTestBuffer(100);

  MockSocketCallback write_done1;
  MockSocketCallback write_done2;

  // Either callback may be the one that fires; it destroys the fixture's
  // sockets and multiplexers when it does.
  EXPECT_CALL(write_done1, OnDone(net::ERR_FAILED))
      .Times(AtMost(1))
      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
  EXPECT_CALL(write_done2, OnDone(net::ERR_FAILED))
      .Times(AtMost(1))
      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));

  EXPECT_EQ(net::ERR_IO_PENDING,
            host_socket1_->Write(
                payload.get(), payload->size(),
                base::Bind(&MockSocketCallback::OnDone,
                           base::Unretained(&write_done1))));
  EXPECT_EQ(net::ERR_IO_PENDING,
            host_socket2_->Write(
                payload.get(), payload->size(),
                base::Bind(&MockSocketCallback::OnDone,
                           base::Unretained(&write_done2))));

  base::RunLoop().RunUntilIdle();

  // Check that the sockets were destroyed.
  EXPECT_FALSE(host_mux_.get());
}
// Configures the host session for asynchronous channel creation with an
// AUTHENTICATION_FAILED error, then requests two channels. The first
// callback, if it runs, must receive NULL; it cancels creation of the second
// channel and deletes everything, so the second callback must never run.
TEST_F(ChannelMultiplexerTest, SessionFail) {
  host_session_.set_async_creation(true);
  host_session_.set_error(AUTHENTICATION_FAILED);

  MockConnectCallback first_connect;
  MockConnectCallback second_connect;

  host_mux_->CreateStreamChannel(
      kTestChannelName,
      base::Bind(&MockConnectCallback::OnConnected,
                 base::Unretained(&first_connect)));
  host_mux_->CreateStreamChannel(
      kTestChannelName2,
      base::Bind(&MockConnectCallback::OnConnected,
                 base::Unretained(&second_connect)));

  EXPECT_CALL(first_connect, OnConnectedPtr(NULL))
      .Times(AtMost(1))
      .WillOnce(InvokeWithoutArgs(
          this, &ChannelMultiplexerTest::DeleteAfterSessionFail));
  EXPECT_CALL(second_connect, OnConnectedPtr(_))
      .Times(0);

  base::RunLoop().RunUntilIdle();
}
} // namespace protocol
} // namespace remoting