// // Copyright (C) 2012 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // #include "shill/net/rtnl_handler.h" #include <string> #include <gtest/gtest.h> #include <net/if.h> #include <sys/socket.h> #include <linux/netlink.h> // Needs typedefs from sys/socket.h. #include <linux/rtnetlink.h> #include <sys/ioctl.h> #include <base/bind.h> #include "shill/mock_log.h" #include "shill/net/mock_io_handler_factory.h" #include "shill/net/mock_sockets.h" #include "shill/net/rtnl_message.h" using base::Bind; using base::Callback; using base::Unretained; using std::string; using testing::_; using testing::A; using testing::DoAll; using testing::ElementsAre; using testing::HasSubstr; using testing::Return; using testing::ReturnArg; using testing::StrictMock; using testing::Test; namespace shill { namespace { const int kTestInterfaceIndex = 4; ACTION(SetInterfaceIndex) { if (arg2) { reinterpret_cast<struct ifreq*>(arg2)->ifr_ifindex = kTestInterfaceIndex; } } MATCHER_P(MessageType, message_type, "") { return std::get<0>(arg).type() == message_type; } } // namespace class RTNLHandlerTest : public Test { public: RTNLHandlerTest() : sockets_(new StrictMock<MockSockets>()), callback_(Bind(&RTNLHandlerTest::HandlerCallback, Unretained(this))), dummy_message_(RTNLMessage::kTypeLink, RTNLMessage::kModeGet, 0, 0, 0, 0, IPAddress::kFamilyUnknown) { } virtual void SetUp() { RTNLHandler::GetInstance()->io_handler_factory_ = &io_handler_factory_; RTNLHandler::GetInstance()->sockets_.reset(sockets_); } virtual void TearDown() { RTNLHandler::GetInstance()->Stop(); } uint32_t GetRequestSequence() { return RTNLHandler::GetInstance()->request_sequence_; } void SetRequestSequence(uint32_t sequence) { RTNLHandler::GetInstance()->request_sequence_ = sequence; } bool IsSequenceInErrorMaskWindow(uint32_t sequence) { return RTNLHandler::GetInstance()->IsSequenceInErrorMaskWindow(sequence); } void SetErrorMask(uint32_t sequence, const RTNLHandler::ErrorMask& error_mask) { return RTNLHandler::GetInstance()->SetErrorMask(sequence, error_mask); } RTNLHandler::ErrorMask GetAndClearErrorMask(uint32_t sequence) { return RTNLHandler::GetInstance()->GetAndClearErrorMask(sequence); } int GetErrorWindowSize() { return RTNLHandler::kErrorWindowSize; } MOCK_METHOD1(HandlerCallback, void(const RTNLMessage&)); protected: static const int kTestSocket; static const int kTestDeviceIndex; static const char kTestDeviceName[]; void AddLink(); void AddNeighbor(); void StartRTNLHandler(); void StopRTNLHandler(); void ReturnError(uint32_t sequence, int error_number); MockSockets* sockets_; StrictMock<MockIOHandlerFactory> io_handler_factory_; Callback<void(const RTNLMessage&)> callback_; RTNLMessage dummy_message_; }; const int RTNLHandlerTest::kTestSocket = 123; const int RTNLHandlerTest::kTestDeviceIndex = 123456; const char RTNLHandlerTest::kTestDeviceName[] = "test-device"; void RTNLHandlerTest::StartRTNLHandler() { EXPECT_CALL(*sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE)) .WillOnce(Return(kTestSocket)); EXPECT_CALL(*sockets_, Bind(kTestSocket, _, sizeof(sockaddr_nl))) .WillOnce(Return(0)); EXPECT_CALL(*sockets_, SetReceiveBuffer(kTestSocket, _)).WillOnce(Return(0)); EXPECT_CALL(io_handler_factory_, CreateIOInputHandler(kTestSocket, _, _)); RTNLHandler::GetInstance()->Start(0); } void RTNLHandlerTest::StopRTNLHandler() { EXPECT_CALL(*sockets_, Close(kTestSocket)).WillOnce(Return(0)); RTNLHandler::GetInstance()->Stop(); } void RTNLHandlerTest::AddLink() { RTNLMessage message(RTNLMessage::kTypeLink, RTNLMessage::kModeAdd, 0, 0, 0, kTestDeviceIndex, IPAddress::kFamilyIPv4); message.SetAttribute(static_cast<uint16_t>(IFLA_IFNAME), ByteString(string(kTestDeviceName), true)); ByteString b(message.Encode()); InputData data(b.GetData(), b.GetLength()); RTNLHandler::GetInstance()->ParseRTNL(&data); } void RTNLHandlerTest::AddNeighbor() { RTNLMessage message(RTNLMessage::kTypeNeighbor, RTNLMessage::kModeAdd, 0, 0, 0, kTestDeviceIndex, IPAddress::kFamilyIPv4); ByteString encoded(message.Encode()); InputData data(encoded.GetData(), encoded.GetLength()); RTNLHandler::GetInstance()->ParseRTNL(&data); } void RTNLHandlerTest::ReturnError(uint32_t sequence, int error_number) { struct { struct nlmsghdr hdr; struct nlmsgerr err; } errmsg; memset(&errmsg, 0, sizeof(errmsg)); errmsg.hdr.nlmsg_type = NLMSG_ERROR; errmsg.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(errmsg.err)); errmsg.hdr.nlmsg_seq = sequence; errmsg.err.error = -error_number; InputData data(reinterpret_cast<unsigned char*>(&errmsg), sizeof(errmsg)); RTNLHandler::GetInstance()->ParseRTNL(&data); } TEST_F(RTNLHandlerTest, ListenersInvoked) { StartRTNLHandler(); std::unique_ptr<RTNLListener> link_listener( new RTNLListener(RTNLHandler::kRequestLink, callback_)); std::unique_ptr<RTNLListener> neighbor_listener( new RTNLListener(RTNLHandler::kRequestNeighbor, callback_)); EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>())) .With(MessageType(RTNLMessage::kTypeLink)); EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>())) .With(MessageType(RTNLMessage::kTypeNeighbor)); AddLink(); AddNeighbor(); StopRTNLHandler(); } TEST_F(RTNLHandlerTest, GetInterfaceName) { EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("")); { struct ifreq ifr; string name(sizeof(ifr.ifr_name), 'x'); EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex(name)); } const int kTestSocket = 123; EXPECT_CALL(*sockets_, Socket(PF_INET, SOCK_DGRAM, 0)) .Times(3) .WillOnce(Return(-1)) .WillRepeatedly(Return(kTestSocket)); EXPECT_CALL(*sockets_, Ioctl(kTestSocket, SIOCGIFINDEX, _)) .WillOnce(Return(-1)) .WillOnce(DoAll(SetInterfaceIndex(), Return(0))); EXPECT_CALL(*sockets_, Close(kTestSocket)) .Times(2) .WillRepeatedly(Return(0)); EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("eth0")); EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("wlan0")); EXPECT_EQ(kTestInterfaceIndex, RTNLHandler::GetInstance()->GetInterfaceIndex("usb0")); } TEST_F(RTNLHandlerTest, IsSequenceInErrorMaskWindow) { const uint32_t kRequestSequence = 1234; SetRequestSequence(kRequestSequence); EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence + 1)); EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence)); EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence - 1)); EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence - GetErrorWindowSize() + 1)); EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence - GetErrorWindowSize())); EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence - GetErrorWindowSize() - 1)); } TEST_F(RTNLHandlerTest, SendMessageReturnsErrorAndAdvancesSequenceNumber) { StartRTNLHandler(); const uint32_t kSequenceNumber = 123; SetRequestSequence(kSequenceNumber); EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(Return(-1)); EXPECT_FALSE(RTNLHandler::GetInstance()->SendMessage(&dummy_message_)); // Sequence number should still increment even if there was a failure. EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence()); StopRTNLHandler(); } TEST_F(RTNLHandlerTest, SendMessageWithEmptyMask) { StartRTNLHandler(); const uint32_t kSequenceNumber = 123; SetRequestSequence(kSequenceNumber); SetErrorMask(kSequenceNumber, {1, 2, 3}); EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>()); EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessageWithErrorMask( &dummy_message_, {})); EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence()); EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty()); StopRTNLHandler(); } TEST_F(RTNLHandlerTest, SendMessageWithErrorMask) { StartRTNLHandler(); const uint32_t kSequenceNumber = 123; SetRequestSequence(kSequenceNumber); EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>()); EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessageWithErrorMask( &dummy_message_, {1, 2, 3})); EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence()); EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber + 1).empty()); EXPECT_THAT(GetAndClearErrorMask(kSequenceNumber), ElementsAre(1, 2, 3)); // A second call to GetAndClearErrorMask() returns an empty vector. EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty()); StopRTNLHandler(); } TEST_F(RTNLHandlerTest, SendMessageInferredErrorMasks) { struct { RTNLMessage::Type type; RTNLMessage::Mode mode; RTNLHandler::ErrorMask mask; } expectations[] = { { RTNLMessage::kTypeLink, RTNLMessage::kModeGet, {} }, { RTNLMessage::kTypeLink, RTNLMessage::kModeAdd, {EEXIST} }, { RTNLMessage::kTypeLink, RTNLMessage::kModeDelete, {ESRCH, ENODEV} }, { RTNLMessage::kTypeAddress, RTNLMessage::kModeDelete, {ESRCH, ENODEV, EADDRNOTAVAIL} } }; const uint32_t kSequenceNumber = 123; EXPECT_CALL(*sockets_, Send(_, _, _, 0)).WillRepeatedly(ReturnArg<2>()); for (const auto& expectation : expectations) { SetRequestSequence(kSequenceNumber); RTNLMessage message(expectation.type, expectation.mode, 0, 0, 0, 0, IPAddress::kFamilyUnknown); EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessage(&message)); EXPECT_EQ(expectation.mask, GetAndClearErrorMask(kSequenceNumber)); } } TEST_F(RTNLHandlerTest, MaskedError) { StartRTNLHandler(); const uint32_t kSequenceNumber = 123; SetRequestSequence(kSequenceNumber); EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>()); EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessageWithErrorMask( &dummy_message_, {1, 2, 3})); ScopedMockLog log; // This error will be not be masked since this sequence number has no mask. EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 1"))).Times(1); ReturnError(kSequenceNumber - 1, 1); // This error will be masked. EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 2"))).Times(0); ReturnError(kSequenceNumber, 2); // This second error will be not be masked since the error mask was removed. EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 3"))).Times(1); ReturnError(kSequenceNumber, 3); StopRTNLHandler(); } } // namespace shill