// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "remoting/protocol/connection_tester.h" #include "base/bind.h" #include "base/message_loop/message_loop.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/socket/stream_socket.h" #include "testing/gtest/include/gtest/gtest.h" namespace remoting { namespace protocol { StreamConnectionTester::StreamConnectionTester(net::StreamSocket* client_socket, net::StreamSocket* host_socket, int message_size, int message_count) : message_loop_(base::MessageLoop::current()), host_socket_(host_socket), client_socket_(client_socket), message_size_(message_size), test_data_size_(message_size * message_count), done_(false), write_errors_(0), read_errors_(0) { } StreamConnectionTester::~StreamConnectionTester() { } void StreamConnectionTester::Start() { InitBuffers(); DoRead(); DoWrite(); } void StreamConnectionTester::CheckResults() { EXPECT_EQ(0, write_errors_); EXPECT_EQ(0, read_errors_); ASSERT_EQ(test_data_size_, input_buffer_->offset()); output_buffer_->SetOffset(0); ASSERT_EQ(test_data_size_, output_buffer_->size()); EXPECT_EQ(0, memcmp(output_buffer_->data(), input_buffer_->StartOfBuffer(), test_data_size_)); } void StreamConnectionTester::Done() { done_ = true; message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); } void StreamConnectionTester::InitBuffers() { output_buffer_ = new net::DrainableIOBuffer( new net::IOBuffer(test_data_size_), test_data_size_); for (int i = 0; i < test_data_size_; ++i) { output_buffer_->data()[i] = static_cast<char>(i); } input_buffer_ = new net::GrowableIOBuffer(); } void StreamConnectionTester::DoWrite() { int result = 1; while (result > 0) { if (output_buffer_->BytesRemaining() == 0) break; int bytes_to_write = std::min(output_buffer_->BytesRemaining(), message_size_); result = client_socket_->Write( output_buffer_.get(), bytes_to_write, base::Bind(&StreamConnectionTester::OnWritten, base::Unretained(this))); HandleWriteResult(result); } } void StreamConnectionTester::OnWritten(int result) { HandleWriteResult(result); DoWrite(); } void StreamConnectionTester::HandleWriteResult(int result) { if (result <= 0 && result != net::ERR_IO_PENDING) { LOG(ERROR) << "Received error " << result << " when trying to write"; write_errors_++; Done(); } else if (result > 0) { output_buffer_->DidConsume(result); } } void StreamConnectionTester::DoRead() { int result = 1; while (result > 0) { input_buffer_->SetCapacity(input_buffer_->offset() + message_size_); result = host_socket_->Read( input_buffer_.get(), message_size_, base::Bind(&StreamConnectionTester::OnRead, base::Unretained(this))); HandleReadResult(result); }; } void StreamConnectionTester::OnRead(int result) { HandleReadResult(result); if (!done_) DoRead(); // Don't try to read again when we are done reading. } void StreamConnectionTester::HandleReadResult(int result) { if (result <= 0 && result != net::ERR_IO_PENDING) { LOG(ERROR) << "Received error " << result << " when trying to read"; read_errors_++; Done(); } else if (result > 0) { // Allocate memory for the next read. input_buffer_->set_offset(input_buffer_->offset() + result); if (input_buffer_->offset() == test_data_size_) Done(); } } DatagramConnectionTester::DatagramConnectionTester(net::Socket* client_socket, net::Socket* host_socket, int message_size, int message_count, int delay_ms) : message_loop_(base::MessageLoop::current()), host_socket_(host_socket), client_socket_(client_socket), message_size_(message_size), message_count_(message_count), delay_ms_(delay_ms), done_(false), write_errors_(0), read_errors_(0), packets_sent_(0), packets_received_(0), bad_packets_received_(0) { sent_packets_.resize(message_count_); } DatagramConnectionTester::~DatagramConnectionTester() { } void DatagramConnectionTester::Start() { DoRead(); DoWrite(); } void DatagramConnectionTester::CheckResults() { EXPECT_EQ(0, write_errors_); EXPECT_EQ(0, read_errors_); EXPECT_EQ(0, bad_packets_received_); // Verify that we've received at least one packet. EXPECT_GT(packets_received_, 0); VLOG(0) << "Received " << packets_received_ << " packets out of " << message_count_; } void DatagramConnectionTester::Done() { done_ = true; message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); } void DatagramConnectionTester::DoWrite() { if (packets_sent_ >= message_count_) { Done(); return; } scoped_refptr<net::IOBuffer> packet(new net::IOBuffer(message_size_)); for (int i = 0; i < message_size_; ++i) { packet->data()[i] = static_cast<char>(i); } sent_packets_[packets_sent_] = packet; // Put index of this packet in the beginning of the packet body. memcpy(packet->data(), &packets_sent_, sizeof(packets_sent_)); int result = client_socket_->Write( packet.get(), message_size_, base::Bind(&DatagramConnectionTester::OnWritten, base::Unretained(this))); HandleWriteResult(result); } void DatagramConnectionTester::OnWritten(int result) { HandleWriteResult(result); } void DatagramConnectionTester::HandleWriteResult(int result) { if (result <= 0 && result != net::ERR_IO_PENDING) { LOG(ERROR) << "Received error " << result << " when trying to write"; write_errors_++; Done(); } else if (result > 0) { EXPECT_EQ(message_size_, result); packets_sent_++; message_loop_->PostDelayedTask( FROM_HERE, base::Bind(&DatagramConnectionTester::DoWrite, base::Unretained(this)), base::TimeDelta::FromMilliseconds(delay_ms_)); } } void DatagramConnectionTester::DoRead() { int result = 1; while (result > 0) { int kReadSize = message_size_ * 2; read_buffer_ = new net::IOBuffer(kReadSize); result = host_socket_->Read( read_buffer_.get(), kReadSize, base::Bind(&DatagramConnectionTester::OnRead, base::Unretained(this))); HandleReadResult(result); }; } void DatagramConnectionTester::OnRead(int result) { HandleReadResult(result); DoRead(); } void DatagramConnectionTester::HandleReadResult(int result) { if (result <= 0 && result != net::ERR_IO_PENDING) { // Error will be received after the socket is closed. LOG(ERROR) << "Received error " << result << " when trying to read"; read_errors_++; Done(); } else if (result > 0) { packets_received_++; if (message_size_ != result) { // Invalid packet size; bad_packets_received_++; } else { // Validate packet body. int packet_id; memcpy(&packet_id, read_buffer_->data(), sizeof(packet_id)); if (packet_id < 0 || packet_id >= message_count_) { bad_packets_received_++; } else { if (memcmp(read_buffer_->data(), sent_packets_[packet_id]->data(), message_size_) != 0) bad_packets_received_++; } } } } } // namespace protocol } // namespace remoting