//
// Copyright (C) 2012 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "shill/net/rtnl_handler.h"
#include <string>
#include <gtest/gtest.h>
#include <net/if.h>
#include <sys/socket.h>
#include <linux/netlink.h> // Needs typedefs from sys/socket.h.
#include <linux/rtnetlink.h>
#include <sys/ioctl.h>
#include <base/bind.h>
#include "shill/mock_log.h"
#include "shill/net/mock_io_handler_factory.h"
#include "shill/net/mock_sockets.h"
#include "shill/net/rtnl_message.h"
using base::Bind;
using base::Callback;
using base::Unretained;
using std::string;
using testing::_;
using testing::A;
using testing::DoAll;
using testing::ElementsAre;
using testing::HasSubstr;
using testing::Return;
using testing::ReturnArg;
using testing::StrictMock;
using testing::Test;
namespace shill {
namespace {

// Interface index that the mocked SIOCGIFINDEX ioctl reports back.
const int kTestInterfaceIndex = 4;

// gmock action intended for Sockets::Ioctl(): if the third argument (the
// ifreq out-parameter) is non-null, writes kTestInterfaceIndex into its
// ifr_ifindex field.
ACTION(SetInterfaceIndex) {
  struct ifreq* const request = reinterpret_cast<struct ifreq*>(arg2);
  if (request != nullptr) {
    request->ifr_ifindex = kTestInterfaceIndex;
  }
}

// Argument-tuple matcher for use with EXPECT_CALL(...).With(...): succeeds
// when the first element of the tuple has type() == |message_type|.
MATCHER_P(MessageType, message_type, "") {
  const auto& message = std::get<0>(arg);
  return message.type() == message_type;
}

}  // namespace
class RTNLHandlerTest : public Test {
public:
RTNLHandlerTest()
: sockets_(new StrictMock<MockSockets>()),
callback_(Bind(&RTNLHandlerTest::HandlerCallback, Unretained(this))),
dummy_message_(RTNLMessage::kTypeLink,
RTNLMessage::kModeGet,
0,
0,
0,
0,
IPAddress::kFamilyUnknown) {
}
virtual void SetUp() {
RTNLHandler::GetInstance()->io_handler_factory_ = &io_handler_factory_;
RTNLHandler::GetInstance()->sockets_.reset(sockets_);
}
virtual void TearDown() {
RTNLHandler::GetInstance()->Stop();
}
uint32_t GetRequestSequence() {
return RTNLHandler::GetInstance()->request_sequence_;
}
void SetRequestSequence(uint32_t sequence) {
RTNLHandler::GetInstance()->request_sequence_ = sequence;
}
bool IsSequenceInErrorMaskWindow(uint32_t sequence) {
return RTNLHandler::GetInstance()->IsSequenceInErrorMaskWindow(sequence);
}
void SetErrorMask(uint32_t sequence,
const RTNLHandler::ErrorMask& error_mask) {
return RTNLHandler::GetInstance()->SetErrorMask(sequence, error_mask);
}
RTNLHandler::ErrorMask GetAndClearErrorMask(uint32_t sequence) {
return RTNLHandler::GetInstance()->GetAndClearErrorMask(sequence);
}
int GetErrorWindowSize() {
return RTNLHandler::kErrorWindowSize;
}
MOCK_METHOD1(HandlerCallback, void(const RTNLMessage&));
protected:
static const int kTestSocket;
static const int kTestDeviceIndex;
static const char kTestDeviceName[];
void AddLink();
void AddNeighbor();
void StartRTNLHandler();
void StopRTNLHandler();
void ReturnError(uint32_t sequence, int error_number);
MockSockets* sockets_;
StrictMock<MockIOHandlerFactory> io_handler_factory_;
Callback<void(const RTNLMessage&)> callback_;
RTNLMessage dummy_message_;
};
// Out-of-class definitions of the fixture's static constants.
// kTestSocket is the descriptor returned by the mocked Socket() call in
// StartRTNLHandler(). kTestDeviceIndex and kTestDeviceName describe the link
// injected by AddLink().
const int RTNLHandlerTest::kTestSocket{123};
const int RTNLHandlerTest::kTestDeviceIndex{123456};
const char RTNLHandlerTest::kTestDeviceName[] = "test-device";
void RTNLHandlerTest::StartRTNLHandler() {
EXPECT_CALL(*sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE))
.WillOnce(Return(kTestSocket));
EXPECT_CALL(*sockets_, Bind(kTestSocket, _, sizeof(sockaddr_nl)))
.WillOnce(Return(0));
EXPECT_CALL(*sockets_, SetReceiveBuffer(kTestSocket, _)).WillOnce(Return(0));
EXPECT_CALL(io_handler_factory_, CreateIOInputHandler(kTestSocket, _, _));
RTNLHandler::GetInstance()->Start(0);
}
void RTNLHandlerTest::StopRTNLHandler() {
EXPECT_CALL(*sockets_, Close(kTestSocket)).WillOnce(Return(0));
RTNLHandler::GetInstance()->Stop();
}
void RTNLHandlerTest::AddLink() {
RTNLMessage message(RTNLMessage::kTypeLink,
RTNLMessage::kModeAdd,
0,
0,
0,
kTestDeviceIndex,
IPAddress::kFamilyIPv4);
message.SetAttribute(static_cast<uint16_t>(IFLA_IFNAME),
ByteString(string(kTestDeviceName), true));
ByteString b(message.Encode());
InputData data(b.GetData(), b.GetLength());
RTNLHandler::GetInstance()->ParseRTNL(&data);
}
void RTNLHandlerTest::AddNeighbor() {
RTNLMessage message(RTNLMessage::kTypeNeighbor,
RTNLMessage::kModeAdd,
0,
0,
0,
kTestDeviceIndex,
IPAddress::kFamilyIPv4);
ByteString encoded(message.Encode());
InputData data(encoded.GetData(), encoded.GetLength());
RTNLHandler::GetInstance()->ParseRTNL(&data);
}
// Builds a netlink NLMSG_ERROR reply for |sequence| carrying |error_number|,
// then feeds it to RTNLHandler::ParseRTNL() as if the kernel had sent it.
void RTNLHandlerTest::ReturnError(uint32_t sequence, int error_number) {
  // Wire layout: a netlink header immediately followed by the error payload.
  struct {
    struct nlmsghdr hdr;
    struct nlmsgerr err;
  } reply;
  memset(&reply, 0, sizeof(reply));

  reply.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(reply.err));
  reply.hdr.nlmsg_type = NLMSG_ERROR;
  reply.hdr.nlmsg_seq = sequence;
  // nlmsgerr carries the error as a negated errno value.
  reply.err.error = -error_number;

  unsigned char* const raw = reinterpret_cast<unsigned char*>(&reply);
  InputData data(raw, sizeof(reply));
  RTNLHandler::GetInstance()->ParseRTNL(&data);
}
// Verifies that a link listener and a neighbor listener each receive exactly
// one message of their own type.
TEST_F(RTNLHandlerTest, ListenersInvoked) {
  StartRTNLHandler();

  // Both listeners route into HandlerCallback(). They live until the end of
  // the test body and are destroyed in reverse declaration order.
  RTNLListener link_listener(RTNLHandler::kRequestLink, callback_);
  RTNLListener neighbor_listener(RTNLHandler::kRequestNeighbor, callback_);

  // One callback invocation per message type (default cardinality: once).
  EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>()))
      .With(MessageType(RTNLMessage::kTypeLink));
  EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>()))
      .With(MessageType(RTNLMessage::kTypeNeighbor));

  AddLink();
  AddNeighbor();

  StopRTNLHandler();
}
// Exercises RTNLHandler::GetInterfaceIndex(); the test keeps its historical
// name. Invalid names are rejected without any socket calls (the StrictMock
// has no expectations yet). Valid names go through socket() and the
// SIOCGIFINDEX ioctl, and a failure of either step yields -1.
TEST_F(RTNLHandlerTest, GetInterfaceName) {
  // Empty name.
  EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex(""));
  {
    // A name as long as ifr_name (leaving no room for a NUL terminator).
    struct ifreq ifr;
    string name(sizeof(ifr.ifr_name), 'x');
    EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex(name));
  }

  // Uses the fixture's kTestSocket. Previously a local constant with the
  // same name and value shadowed it.
  EXPECT_CALL(*sockets_, Socket(PF_INET, SOCK_DGRAM, 0))
      .Times(3)
      .WillOnce(Return(-1))
      .WillRepeatedly(Return(kTestSocket));
  EXPECT_CALL(*sockets_, Ioctl(kTestSocket, SIOCGIFINDEX, _))
      .WillOnce(Return(-1))
      .WillOnce(DoAll(SetInterfaceIndex(), Return(0)));
  // Only the two successfully opened sockets are closed.
  EXPECT_CALL(*sockets_, Close(kTestSocket))
      .Times(2)
      .WillRepeatedly(Return(0));

  // socket() fails.
  EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("eth0"));
  // socket() succeeds but the ioctl fails.
  EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("wlan0"));
  // Both succeed; the ioctl action fills in kTestInterfaceIndex.
  EXPECT_EQ(kTestInterfaceIndex,
            RTNLHandler::GetInstance()->GetInterfaceIndex("usb0"));
}
// Checks the bounds of the error-mask window relative to the current request
// sequence number.
TEST_F(RTNLHandlerTest, IsSequenceInErrorMaskWindow) {
  const uint32_t kRequestSequence = 1234;
  const int window = GetErrorWindowSize();
  SetRequestSequence(kRequestSequence);

  // Above the current request sequence: outside.
  EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence + 1));

  // The current sequence number and the one before it: inside.
  EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence));
  EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence - 1));

  // The window spans |window| values ending at kRequestSequence.
  EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence - window + 1));
  EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence - window));
  EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence - window - 1));
}
// A failed send() makes SendMessage() return false, but the request sequence
// number is consumed anyway.
TEST_F(RTNLHandlerTest, SendMessageReturnsErrorAndAdvancesSequenceNumber) {
  StartRTNLHandler();
  const uint32_t kSequenceNumber = 123;
  SetRequestSequence(kSequenceNumber);

  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(Return(-1));
  RTNLHandler* const handler = RTNLHandler::GetInstance();
  EXPECT_FALSE(handler->SendMessage(&dummy_message_));

  EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());
  StopRTNLHandler();
}
// Sending with an empty error mask replaces any mask previously stored for
// the sequence number used.
TEST_F(RTNLHandlerTest, SendMessageWithEmptyMask) {
  StartRTNLHandler();
  const uint32_t kSequenceNumber = 123;
  SetRequestSequence(kSequenceNumber);
  // Pre-seed a stale mask that the send below should clear.
  SetErrorMask(kSequenceNumber, {1, 2, 3});

  // Report a complete send by echoing the length argument.
  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
  RTNLHandler* const handler = RTNLHandler::GetInstance();
  EXPECT_TRUE(handler->SendMessageWithErrorMask(&dummy_message_, {}));

  EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());
  EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty());
  StopRTNLHandler();
}
// The error mask passed to SendMessageWithErrorMask() is recorded only for the
// sequence number the message was sent with, and is one-shot.
TEST_F(RTNLHandlerTest, SendMessageWithErrorMask) {
  StartRTNLHandler();
  const uint32_t kSequenceNumber = 123;
  SetRequestSequence(kSequenceNumber);

  // Report a complete send by echoing the length argument.
  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
  RTNLHandler* const handler = RTNLHandler::GetInstance();
  EXPECT_TRUE(handler->SendMessageWithErrorMask(&dummy_message_, {1, 2, 3}));
  EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());

  // Nothing is stored for the next sequence number...
  EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber + 1).empty());
  // ...the mask is stored for the one actually used...
  EXPECT_THAT(GetAndClearErrorMask(kSequenceNumber), ElementsAre(1, 2, 3));
  // ...and retrieving it also removes it.
  EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty());
  StopRTNLHandler();
}
// Plain SendMessage() records a default error mask derived from the message's
// type and mode.
TEST_F(RTNLHandlerTest, SendMessageInferredErrorMasks) {
  struct ExpectedMask {
    RTNLMessage::Type type;
    RTNLMessage::Mode mode;
    RTNLHandler::ErrorMask mask;
  };
  const ExpectedMask kExpectations[] = {
    { RTNLMessage::kTypeLink, RTNLMessage::kModeGet, {} },
    { RTNLMessage::kTypeLink, RTNLMessage::kModeAdd, {EEXIST} },
    { RTNLMessage::kTypeLink, RTNLMessage::kModeDelete, {ESRCH, ENODEV} },
    { RTNLMessage::kTypeAddress, RTNLMessage::kModeDelete,
      {ESRCH, ENODEV, EADDRNOTAVAIL} }
  };
  const uint32_t kSequenceNumber = 123;

  // The handler is not started in this test, so accept a send on any socket
  // and report it as complete by echoing the length argument.
  EXPECT_CALL(*sockets_, Send(_, _, _, 0)).WillRepeatedly(ReturnArg<2>());

  RTNLHandler* const handler = RTNLHandler::GetInstance();
  for (const ExpectedMask& expected : kExpectations) {
    SetRequestSequence(kSequenceNumber);
    RTNLMessage message(expected.type, expected.mode,
                        0, 0, 0, 0,
                        IPAddress::kFamilyUnknown);
    EXPECT_TRUE(handler->SendMessage(&message));
    EXPECT_EQ(expected.mask, GetAndClearErrorMask(kSequenceNumber));
  }
}
// Errors listed in a sequence number's mask are not logged, but the mask
// applies only once; errors on other sequence numbers are always logged.
TEST_F(RTNLHandlerTest, MaskedError) {
  StartRTNLHandler();
  const uint32_t kSequenceNumber = 123;
  SetRequestSequence(kSequenceNumber);

  EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
  RTNLHandler* const handler = RTNLHandler::GetInstance();
  EXPECT_TRUE(handler->SendMessageWithErrorMask(&dummy_message_, {1, 2, 3}));

  // Start capturing log output only after the message has been sent.
  ScopedMockLog log;

  // kSequenceNumber - 1 has no mask, so this error is logged.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 1"))).Times(1);
  ReturnError(kSequenceNumber - 1, 1);

  // Error 2 is in kSequenceNumber's mask, so it is not logged.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 2"))).Times(0);
  ReturnError(kSequenceNumber, 2);

  // The previous error removed the mask, so this error is logged even though
  // 3 was part of it.
  EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 3"))).Times(1);
  ReturnError(kSequenceNumber, 3);

  StopRTNLHandler();
}
} // namespace shill