// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/udp/udp_client_socket.h"
#include "net/udp/udp_server_socket.h"
#include "base/basictypes.h"
#include "base/metrics/histogram.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/net_test_suite.h"
#include "net/base/net_util.h"
#include "net/base/sys_addrinfo.h"
#include "net/base/test_completion_callback.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "testing/platform_test.h"
namespace net {
namespace {
class UDPSocketTest : public PlatformTest {
 public:
  UDPSocketTest()
      : buffer_(new IOBufferWithSize(kMaxRead)) {
  }

  // Blocks until a datagram is received on |socket|. The sender's address is
  // stored in |recv_from_address_| so that a reply can be sent back with the
  // two-argument SendToSocket(). Returns the received payload, or an empty
  // string if the read failed.
  std::string RecvFromSocket(UDPServerSocket* socket) {
    TestCompletionCallback callback;
    int rv = socket->RecvFrom(buffer_, kMaxRead, &recv_from_address_,
                              &callback);
    if (rv == ERR_IO_PENDING)
      rv = callback.WaitForResult();
    if (rv < 0)
      return "";  // error!
    return std::string(buffer_->data(), rv);
  }

  // Sends |msg| to the endpoint this fixture last received from via
  // RecvFromSocket(). See the three-argument overload for the return value.
  int SendToSocket(UDPServerSocket* socket, const std::string& msg) {
    return SendToSocket(socket, msg, recv_from_address_);
  }

  // Loops until |msg| has been sent to |address| from |socket| or until an
  // error occurs. Returns the number of bytes sent if any were sent;
  // otherwise returns the (non-positive) result of the failing SendTo().
  int SendToSocket(UDPServerSocket* socket,
                   const std::string& msg,
                   const IPEndPoint& address) {
    TestCompletionCallback callback;
    int length = static_cast<int>(msg.length());
    scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
    scoped_refptr<DrainableIOBuffer> buffer(
        new DrainableIOBuffer(io_buffer, length));

    int bytes_sent = 0;
    while (buffer->BytesRemaining()) {
      int rv = socket->SendTo(buffer, buffer->BytesRemaining(),
                              address, &callback);
      if (rv == ERR_IO_PENDING)
        rv = callback.WaitForResult();
      if (rv <= 0)
        return bytes_sent > 0 ? bytes_sent : rv;
      bytes_sent += rv;
      buffer->DidConsume(rv);
    }
    return bytes_sent;
  }

  // Blocks until data is read from the connected |socket|. Returns the data
  // read, or an empty string if the read failed.
  std::string ReadSocket(UDPClientSocket* socket) {
    TestCompletionCallback callback;
    int rv = socket->Read(buffer_, kMaxRead, &callback);
    if (rv == ERR_IO_PENDING)
      rv = callback.WaitForResult();
    if (rv < 0)
      return "";  // error!
    return std::string(buffer_->data(), rv);
  }

  // Loops until |msg| has been written to the connected |socket| or until an
  // error occurs. Returns the number of bytes written if any were written;
  // otherwise returns the (non-positive) result of the failing Write().
  int WriteSocket(UDPClientSocket* socket, const std::string& msg) {
    TestCompletionCallback callback;
    int length = static_cast<int>(msg.length());
    scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
    scoped_refptr<DrainableIOBuffer> buffer(
        new DrainableIOBuffer(io_buffer, length));

    int bytes_sent = 0;
    while (buffer->BytesRemaining()) {
      int rv = socket->Write(buffer, buffer->BytesRemaining(), &callback);
      if (rv == ERR_IO_PENDING)
        rv = callback.WaitForResult();
      if (rv <= 0)
        return bytes_sent > 0 ? bytes_sent : rv;
      bytes_sent += rv;
      buffer->DidConsume(rv);
    }
    return bytes_sent;
  }

 protected:
  // Size of the shared receive buffer; reads are capped at this many bytes.
  static const int kMaxRead = 1024;
  scoped_refptr<IOBufferWithSize> buffer_;
  // Sender of the most recent datagram received by RecvFromSocket().
  IPEndPoint recv_from_address_;
};
// Creates an address from an IP literal |ip_str| and |port| and stores it in
// |address|. Returns false, leaving |address| untouched, if |ip_str| is not a
// valid IP literal.
bool CreateUDPAddress(const std::string& ip_str, int port,
                      IPEndPoint* address) {
  IPAddressNumber ip_number;
  if (!ParseIPLiteralToNumber(ip_str, &ip_number))
    return false;
  *address = IPEndPoint(ip_number, port);
  return true;
}
// Round-trips a message between a connected client and a listening server.
TEST_F(UDPSocketTest, Connect) {
  const int kPort = 9999;
  std::string simple_message("hello world!");

  // Setup the server to listen.
  IPEndPoint bind_address;
  CreateUDPAddress("0.0.0.0", kPort, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  ASSERT_EQ(OK, rv);

  // Setup the client.
  IPEndPoint server_address;
  CreateUDPAddress("127.0.0.1", kPort, &server_address);
  UDPClientSocket client(NULL, NetLog::Source());
  rv = client.Connect(server_address);
  ASSERT_EQ(OK, rv);

  // Client sends to the server. Fatal on failure: otherwise the blocking
  // receive below would wait forever.
  rv = WriteSocket(&client, simple_message);
  ASSERT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Server waits for message. EXPECT_EQ (unlike DCHECK) is also checked in
  // release builds.
  std::string str = RecvFromSocket(&server);
  EXPECT_EQ(simple_message, str);

  // Server echoes reply.
  rv = SendToSocket(&server, simple_message);
  ASSERT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Client waits for response.
  str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);
}
// Verifies that connect() on a UDP socket also filters reads: only datagrams
// sent from the endpoint we connected to are delivered to the socket.
//
// Some documentation suggests that connect() only fixes the destination of
// the client's sends, without restricting which endpoints the client can
// receive from, in which case we would always need recvfrom() to
// disambiguate. Here a second server sends a datagram to the client first;
// the client must still read only the reply from the server it connected to.
TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
  const int kPort1 = 9999;
  const int kPort2 = 10000;
  std::string simple_message("hello world!");
  std::string foreign_message("BAD MESSAGE TO GET!!");

  // Setup the first server to listen.
  IPEndPoint bind_address;
  CreateUDPAddress("0.0.0.0", kPort1, &bind_address);
  UDPServerSocket server1(NULL, NetLog::Source());
  int rv = server1.Listen(bind_address);
  ASSERT_EQ(OK, rv);

  // Setup the second server to listen.
  CreateUDPAddress("0.0.0.0", kPort2, &bind_address);
  UDPServerSocket server2(NULL, NetLog::Source());
  rv = server2.Listen(bind_address);
  ASSERT_EQ(OK, rv);

  // Setup the client, connected to server 1.
  IPEndPoint server_address;
  CreateUDPAddress("127.0.0.1", kPort1, &server_address);
  UDPClientSocket client(NULL, NetLog::Source());
  rv = client.Connect(server_address);
  ASSERT_EQ(OK, rv);

  // Client sends to server1. Fatal on failure so the blocking receive below
  // cannot hang.
  rv = WriteSocket(&client, simple_message);
  ASSERT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Server1 waits for message.
  std::string str = RecvFromSocket(&server1);
  EXPECT_EQ(simple_message, str);

  // Get the client's address.
  IPEndPoint client_address;
  rv = client.GetLocalAddress(&client_address);
  ASSERT_EQ(OK, rv);

  // Server2 sends reply; the client must NOT receive this one.
  rv = SendToSocket(&server2, foreign_message, client_address);
  EXPECT_EQ(foreign_message.length(), static_cast<size_t>(rv));

  // Server1 sends reply.
  rv = SendToSocket(&server1, simple_message, client_address);
  ASSERT_EQ(simple_message.length(), static_cast<size_t>(rv));

  // Client waits for response. This is the core assertion of the test: with
  // DCHECK it was compiled out of release builds and never verified.
  str = ReadSocket(&client);
  EXPECT_EQ(simple_message, str);
}
// Checks that a connected client socket reports a local address and that its
// peer address matches the endpoint it connected to.
TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
  struct TestData {
    std::string remote_address;
    std::string local_address;
    bool may_fail;  // True if Connect() may legitimately fail (IPv6).
  } tests[] = {
    { "127.0.00.1", "127.0.0.1", false },
    { "192.168.1.1", "127.0.0.1", false },
    { "::1", "::1", true },
    { "2001:db8:0::42", "::1", true },
  };

  for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); i++) {
    SCOPED_TRACE(std::string("Connecting from ") + tests[i].local_address +
                 std::string(" to ") + tests[i].remote_address);

    // A parse failure would otherwise silently produce an empty address.
    IPAddressNumber remote_ip;
    ASSERT_TRUE(ParseIPLiteralToNumber(tests[i].remote_address, &remote_ip));
    IPEndPoint remote_address(remote_ip, 80);

    UDPClientSocket client(NULL, NetLog::Source());
    int rv = client.Connect(remote_address);
    if (tests[i].may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
      // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
      // addresses if IPv6 is not configured.
      continue;
    }
    // Accepts OK or ERR_IO_PENDING; any other error fails.
    EXPECT_LE(ERR_IO_PENDING, rv);

    IPEndPoint fetched_local_address;
    rv = client.GetLocalAddress(&fetched_local_address);
    EXPECT_EQ(OK, rv);

    // TODO(mbelshe): figure out how to verify the IP and port against
    // tests[i].local_address.
    // The port is dynamically generated by the udp stack.
    // The IP is the real IP of the client, not necessarily
    // loopback.

    IPEndPoint fetched_remote_address;
    rv = client.GetPeerAddress(&fetched_remote_address);
    EXPECT_EQ(OK, rv);

    EXPECT_EQ(remote_address, fetched_remote_address);
  }
}
// A server bound to port 0 reports the bound IP and a kernel-assigned port.
TEST_F(UDPSocketTest, ServerGetLocalAddress) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  ASSERT_EQ(OK, rv);

  IPEndPoint local_address;
  rv = server.GetLocalAddress(&local_address);
  EXPECT_EQ(OK, rv);

  // Verify that a port was allocated.
  EXPECT_GT(local_address.port(), 0);
  EXPECT_EQ(bind_address.address(), local_address.address());
}
// An unconnected (listening) server socket has no peer address.
TEST_F(UDPSocketTest, ServerGetPeerAddress) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  ASSERT_EQ(OK, rv);

  IPEndPoint peer_address;
  rv = server.GetPeerAddress(&peer_address);
  EXPECT_EQ(ERR_SOCKET_NOT_CONNECTED, rv);
}
// Closes the socket while a read is pending and verifies that the pending
// read's callback has not been run.
TEST_F(UDPSocketTest, CloseWithPendingRead) {
  IPEndPoint bind_address;
  CreateUDPAddress("127.0.0.1", 0, &bind_address);
  UDPServerSocket server(NULL, NetLog::Source());
  int rv = server.Listen(bind_address);
  ASSERT_EQ(OK, rv);

  TestCompletionCallback callback;
  IPEndPoint from;
  rv = server.RecvFrom(buffer_, kMaxRead, &from, &callback);
  // The rest of the test is only meaningful if the read is actually pending.
  ASSERT_EQ(ERR_IO_PENDING, rv);

  server.Close();

  EXPECT_FALSE(callback.have_result());
}
} // namespace
} // namespace net