/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include <atomic> #include <deque> #include <iostream> #include <mutex> #include <arpa/inet.h> #include <gmock/gmock.h> #include <gtest/gtest.h> #include <linux/netfilter/nfnetlink_log.h> #include <netdutils/MockSyscalls.h> #include "NFLogListener.h" using ::testing::ByMove; using ::testing::Exactly; using ::testing::Invoke; using ::testing::Mock; using ::testing::SaveArg; using ::testing::DoAll; using ::testing::Return; using ::testing::StrictMock; using ::testing::_; namespace android { namespace net { using netdutils::Fd; using netdutils::Slice; using netdutils::StatusOr; using netdutils::UniqueFd; using netdutils::forEachNetlinkAttribute; using netdutils::makeSlice; using netdutils::status::ok; constexpr int kNFLogPacketMsgType = (NFNL_SUBSYS_ULOG << 8) | NFULNL_MSG_PACKET; constexpr int kNetlinkMsgDoneType = (NFNL_SUBSYS_NONE << 8) | NLMSG_DONE; class MockNetlinkListener : public NetlinkListenerInterface { public: ~MockNetlinkListener() override = default; MOCK_METHOD1(send, netdutils::Status(const netdutils::Slice msg)); MOCK_METHOD2(subscribe, netdutils::Status(uint16_t type, const DispatchFn& fn)); MOCK_METHOD1(unsubscribe, netdutils::Status(uint16_t type)); MOCK_METHOD0(join, void()); }; class NFLogListenerTest : public testing::Test { protected: NFLogListenerTest() { EXPECT_CALL(*mNLListener, subscribe(kNFLogPacketMsgType, _)) .WillOnce(DoAll(SaveArg<1>(&mPacketFn), Return(ok))); EXPECT_CALL(*mNLListener, subscribe(kNetlinkMsgDoneType, _)) .WillOnce(DoAll(SaveArg<1>(&mDoneFn), Return(ok))); mListener.reset(new NFLogListener(mNLListener)); } ~NFLogListenerTest() { EXPECT_CALL(*mNLListener, unsubscribe(kNFLogPacketMsgType)).WillOnce(Return(ok)); EXPECT_CALL(*mNLListener, unsubscribe(kNetlinkMsgDoneType)).WillOnce(Return(ok)); } static StatusOr<size_t> sendOk(const Slice buf) { return buf.size(); } void subscribe(uint16_t type, NFLogListenerInterface::DispatchFn fn) { // Two sends for cfgCmdBind() & cfgMode(), one send at destruction time for cfgCmdUnbind() EXPECT_CALL(*mNLListener, send(_)).Times(Exactly(3)).WillRepeatedly(Invoke(sendOk)); mListener->subscribe(type, fn); } void sendEmptyMsg(uint16_t type) { struct { nlmsghdr nlmsg; nfgenmsg nfmsg; } msg = {}; msg.nlmsg.nlmsg_type = kNFLogPacketMsgType; msg.nlmsg.nlmsg_len = sizeof(msg); msg.nfmsg.res_id = htons(type); mPacketFn(msg.nlmsg, drop(makeSlice(msg), sizeof(msg.nlmsg))); } NetlinkListenerInterface::DispatchFn mPacketFn; NetlinkListenerInterface::DispatchFn mDoneFn; std::shared_ptr<StrictMock<MockNetlinkListener>> mNLListener{ new StrictMock<MockNetlinkListener>()}; std::unique_ptr<NFLogListener> mListener; }; TEST_F(NFLogListenerTest, subscribe) { constexpr uint16_t kType = 38; const auto dispatchFn = [](const nlmsghdr&, const nfgenmsg&, const netdutils::Slice) {}; subscribe(kType, dispatchFn); } TEST_F(NFLogListenerTest, nlmsgDone) { constexpr uint16_t kType = 38; const auto dispatchFn = [](const nlmsghdr&, const nfgenmsg&, const netdutils::Slice) {}; subscribe(kType, dispatchFn); mDoneFn({}, {}); } TEST_F(NFLogListenerTest, dispatchOk) { int invocations = 0; constexpr uint16_t kType = 38; const auto dispatchFn = [&invocations, kType](const nlmsghdr&, const nfgenmsg& nfmsg, const netdutils::Slice) { EXPECT_EQ(kType, ntohs(nfmsg.res_id)); ++invocations; }; subscribe(kType, dispatchFn); sendEmptyMsg(kType); EXPECT_EQ(1, invocations); } TEST_F(NFLogListenerTest, dispatchUnknownType) { constexpr uint16_t kType = 38; constexpr uint16_t kBadType = kType + 1; const auto dispatchFn = [](const nlmsghdr&, const nfgenmsg&, const netdutils::Slice) { // Expect no invocations ASSERT_TRUE(false); }; subscribe(kType, dispatchFn); sendEmptyMsg(kBadType); } } // namespace net } // namespace android