C++程序  |  378行  |  10.76 KB

/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <errno.h>
#include <netdb.h>
#include <string.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/uio.h>

#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <linux/inet_diag.h>

#define LOG_TAG "Netd"

#include <android-base/strings.h>
#include <cutils/log.h>

#include "NetdConstants.h"
#include "SockDiag.h"

#include <chrono>

#ifndef SOCK_DESTROY
#define SOCK_DESTROY 21
#endif

namespace {

// Deleter for std::unique_ptr: hands a non-null addrinfo list to freeaddrinfo().
struct AddrinfoDeleter {
    void operator()(addrinfo *list) {
        if (list == nullptr) {
            return;
        }
        freeaddrinfo(list);
    }
};

// Owning pointer to an addrinfo list; the list is freed when the pointer is destroyed.
using ScopedAddrinfo = std::unique_ptr<addrinfo, AddrinfoDeleter>;

// Non-blocking check for a netlink error reply queued on fd.
//
// Peeks at the next message without waiting. Returns:
//   - 0 if nothing is queued (EAGAIN), or if the queued message is not an error reply of
//     exactly the expected size; such a message is left queued for the caller.
//   - -errno if the peek itself fails for any other reason.
//   - the error field of the reply if it is an NLMSG_ERROR message, which is then consumed.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) reply;

    const ssize_t peeked = recv(fd, &reply, sizeof(reply), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // Either there was nothing to read (fine) or the read really failed.
        if (errno == EAGAIN) {
            return 0;
        }
        return -errno;
    }

    const bool isErrorReply =
            peeked == static_cast<ssize_t>(sizeof(reply)) && reply.h.nlmsg_type == NLMSG_ERROR;
    if (!isErrorReply) {
        // Some other reply is pending; it stays queued for the caller to read.
        return 0;
    }

    // Remove the error message from the queue before reporting it.
    recv(fd, &reply, sizeof(reply), 0);
    return reply.err.error;
}

}  // namespace

// Opens and connects the two NETLINK_INET_DIAG sockets used by this object: mSock (dump
// requests and their replies) and mWriteSock (SOCK_DESTROY requests).
//
// Returns false if hasSocks() is already true, or if creating or connecting either socket
// fails (closeSocks() is called in that case). Returns true once both are connected.
bool SockDiag::open() {
    if (hasSocks()) {
        return false;
    }

    // SOCK_CLOEXEC keeps these descriptors from being inherited by any program this process
    // later execs.
    mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
    mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
    if (!hasSocks()) {
        closeSocks();
        return false;
    }

    // Only the family is set; every other sockaddr_nl field is zero.
    sockaddr_nl nl = { .nl_family = AF_NETLINK };
    if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
        (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
        closeSocks();
        return false;
    }

    return true;
}

// Sends a SOCK_DIAG_BY_FAMILY dump request (NLM_F_REQUEST | NLM_F_DUMP) on mSock.
//
// proto, family and states are placed in the request's sdiag_protocol, sdiag_family and
// idiag_states fields. iov[0] is overwritten to point at the request header built here;
// iov[1..iovcnt-1] are sent after it unchanged, and nlmsg_len covers all of them.
//
// Returns -EINVAL if iov is null or iovcnt < 1, -errno if writev fails, -EIO if fewer bytes
// than the whole message were written, and otherwise the result of checkError(mSock).
int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states,
                              iovec *iov, int iovcnt) {
    // iov[0] is written below, so the caller must supply at least one element.
    if (iov == nullptr || iovcnt < 1) {
        return -EINVAL;
    }

    struct {
        nlmsghdr nlh;
        inet_diag_req_v2 req;
    } __attribute__((__packed__)) request = {
        .nlh = {
            .nlmsg_type = SOCK_DIAG_BY_FAMILY,
            .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
        },
        .req = {
            .sdiag_family = family,
            .sdiag_protocol = proto,
            .idiag_states = states,
        },
    };

    size_t len = 0;
    iov[0].iov_base = &request;
    iov[0].iov_len = sizeof(request);
    for (int i = 0; i < iovcnt; i++) {
        len += iov[i].iov_len;
    }
    request.nlh.nlmsg_len = len;

    const ssize_t written = writev(mSock, iov, iovcnt);
    if (written == -1) {
        return -errno;
    }
    if (written != static_cast<ssize_t>(len)) {
        // A short write does not set errno, so -errno could be stale or even 0 (success).
        return -EIO;
    }

    return checkError(mSock);
}

int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
    iovec iov[] = {
        { nullptr, 0 },
    };
    return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
}

// Sends a dump request for sockets of the given protocol and family, attaching an
// INET_DIAG_REQ_BYTECODE attribute built from addrstr.
//
// addrstr must be a numeric address (it is parsed with AI_NUMERICHOST). An IPv4 address may
// be combined with either family; with AF_INET6 it is placed in the last 32 bits of a
// ::ffff:0:0/96 (IPv4-mapped) address. An IPv6 address is only accepted with AF_INET6.
//
// Returns -EINVAL if addrstr cannot be parsed, -EAFNOSUPPORT for any other address/family
// combination, and otherwise the result of the iovec-based sendDumpRequest overload.
int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
    addrinfo numericOnly = { .ai_flags = AI_NUMERICHOST };
    addrinfo *parsed;

    // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
    // doing string conversions when they're not necessary.
    if (getaddrinfo(addrstr, nullptr, &numericOnly, &parsed) != 0) {
        return -EINVAL;
    }

    // Owns the getaddrinfo result so that every return path below frees it.
    ScopedAddrinfo parsedOwner(parsed);

    // IPv4-mapped IPv6 template; the last 32-bit word is filled in only if it is needed.
    in6_addr v4mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };

    // Select the raw address bytes to send. These point either into the getaddrinfo result
    // or at v4mapped, both of which outlive the send at the end of this function.
    void *matchAddr;
    uint8_t matchAddrLen;
    const int parsedFamily = parsed->ai_family;
    if (parsedFamily == AF_INET6 && family == AF_INET6) {
        in6_addr& v6 = reinterpret_cast<sockaddr_in6*>(parsed->ai_addr)->sin6_addr;
        matchAddr = &v6;
        matchAddrLen = sizeof(v6);
    } else if (parsedFamily == AF_INET && family == AF_INET) {
        in_addr& v4 = reinterpret_cast<sockaddr_in*>(parsed->ai_addr)->sin_addr;
        matchAddr = &v4;
        matchAddrLen = sizeof(v4);
    } else if (parsedFamily == AF_INET && family == AF_INET6) {
        in_addr& v4 = reinterpret_cast<sockaddr_in*>(parsed->ai_addr)->sin_addr;
        v4mapped.s6_addr32[3] = v4.s_addr;
        matchAddr = &v4mapped;
        matchAddrLen = sizeof(v4mapped);
    } else {
        return -EAFNOSUPPORT;
    }

    // Full-length prefix: every bit of the address takes part in the comparison.
    const uint8_t prefixBits = matchAddrLen * 8;

    // Jump offsets for the single bytecode op. "yes" is the size of the op, its host
    // condition and the address; "no" is 4 bytes beyond that.
    // NOTE(review): presumably "yes" lands exactly at the end of the bytecode (accept) and
    // "no" beyond it (reject) -- confirm against the kernel's inet_diag bytecode rules.
    const uint8_t onMatch = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + matchAddrLen;
    const uint8_t onMismatch = onMatch + 4;

    struct {
        nlattr nla;
        inet_diag_bc_op op;
        inet_diag_hostcond cond;
    } __attribute__((__packed__)) attrs = {
        .nla = {
            .nla_type = INET_DIAG_REQ_BYTECODE,
        },
        .op = {
            INET_DIAG_BC_S_COND,
            onMatch,
            onMismatch,
        },
        .cond = {
            family,
            prefixBits,
            -1,  // port; NOTE(review): presumably -1 means "any port" -- confirm.
            {}
        },
    };

    // The address bytes travel in their own iovec element right after attrs, so the
    // attribute length has to include them.
    attrs.nla.nla_len = sizeof(attrs) + matchAddrLen;

    // Every state bit except TCP_TIME_WAIT.
    uint32_t allButTimeWait = ~(1 << TCP_TIME_WAIT);

    iovec iov[] = {
        { nullptr, 0 },  // Filled in with the request header by the callee.
        { &attrs, sizeof(attrs) },
        { matchAddr, matchAddrLen },
    };

    return sendDumpRequest(proto, family, allButTimeWait, iov, ARRAY_SIZE(iov));
}

// Reads the replies to a dump request from mSock and feeds them to callback.
//
// For each message that is neither NLMSG_DONE nor NLMSG_ERROR, callback(proto, msg) is
// invoked; if it returns true, sockDestroy(proto, msg) is called and its result is ignored.
// On NLMSG_DONE the callback is invoked once with a null message and 0 is returned.
//
// Returns -errno if read() fails, the error field of an NLMSG_ERROR message if one is seen,
// and 0 on NLMSG_DONE or when read() returns 0 bytes.
int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
    char buf[kBufferSize];

    for (;;) {
        const ssize_t received = read(mSock, buf, sizeof(buf));
        if (received < 0) {
            return -errno;
        }

        // Walk every netlink message contained in this read.
        uint32_t remaining = received;
        nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf);
        while (NLMSG_OK(nlh, remaining)) {
            if (nlh->nlmsg_type == NLMSG_DONE) {
                callback(proto, nullptr);
                return 0;
            }

            if (nlh->nlmsg_type == NLMSG_ERROR) {
                nlmsgerr *failure = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
                return failure->error;
            }

            inet_diag_msg *diag = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
            if (callback(proto, diag)) {
                sockDestroy(proto, diag);
            }

            nlh = NLMSG_NEXT(nlh, remaining);
        }

        // A zero-byte read ends the loop, like the end of the dump.
        if (received == 0) {
            return 0;
        }
    }
}

// Sends a SOCK_DESTROY request on mWriteSock for the socket described by msg.
//
// The request copies the family and socket id from msg, uses proto as the protocol, and
// sets idiag_states to the single bit for msg->idiag_state. mSocketsDestroyed is
// incremented when checkError() reports no error for the request.
//
// Returns 0 if msg is null or the request produced no error, -errno if write() fails,
// -EIO if the request was only partially written, and otherwise the value returned by
// checkError(mWriteSock).
int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
    if (msg == nullptr) {
       return 0;
    }

    DestroyRequest request = {
        .nlh = {
            .nlmsg_type = SOCK_DESTROY,
            .nlmsg_flags = NLM_F_REQUEST,
        },
        .req = {
            .sdiag_family = msg->idiag_family,
            .sdiag_protocol = proto,
            .idiag_states = (uint32_t) (1 << msg->idiag_state),
            .id = msg->id,
        },
    };
    request.nlh.nlmsg_len = sizeof(request);

    const ssize_t written = write(mWriteSock, &request, sizeof(request));
    if (written == -1) {
        return -errno;
    }
    if (written < static_cast<ssize_t>(sizeof(request))) {
        // A short write does not set errno, so -errno could be stale or even 0 (success).
        return -EIO;
    }

    int ret = checkError(mWriteSock);
    if (!ret) mSocketsDestroyed++;
    return ret;
}

// Dumps the sockets of the given protocol and family selected by addrstr and requests
// destruction of every one that is reported.
//
// Returns -EBADFD if the sockets are not open, the error from sendDumpRequest if the dump
// request fails, and otherwise the result of readDiagMsg.
int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
    if (!hasSocks()) {
        return -EBADFD;
    }

    const int requestStatus = sendDumpRequest(proto, family, addrstr);
    if (requestStatus != 0) {
        return requestStatus;
    }

    // No filtering here: every socket in the dump is destroyed.
    return readDiagMsg(proto, [] (uint8_t, const inet_diag_msg*) { return true; });
}

// Destroys TCP sockets selected by addrstr.
//
// An address without a ':' is processed for AF_INET first and then for AF_INET6; an
// address containing ':' is processed for AF_INET6 only. Returns the first error
// encountered, otherwise the number of sockets destroyed (mSocketsDestroyed).
int SockDiag::destroySockets(const char *addrstr) {
    Stopwatch s;
    mSocketsDestroyed = 0;

    const bool hasColon = (strchr(addrstr, ':') != nullptr);
    if (!hasColon) {
        const int v4Status = destroySockets(IPPROTO_TCP, AF_INET, addrstr);
        if (v4Status != 0) {
            ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-v4Status));
            return v4Status;
        }
    }

    const int v6Status = destroySockets(IPPROTO_TCP, AF_INET6, addrstr);
    if (v6Status != 0) {
        ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-v6Status));
        return v6Status;
    }

    // Only log when something was actually destroyed.
    if (mSocketsDestroyed > 0) {
        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, s.timeTaken());
    }

    return mSocketsDestroyed;
}

// For IPv4 and then IPv6, dumps TCP sockets in the ESTABLISHED, SYN_SENT and SYN_RECV
// states and passes each dumped socket through destroyFilter via readDiagMsg.
//
// Returns 0 if both families were processed, otherwise the first error from
// sendDumpRequest or readDiagMsg (which is also logged).
int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
    const int proto = IPPROTO_TCP;
    const int families[] = {AF_INET, AF_INET6};

    for (const int family : families) {
        const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
        uint32_t liveStates = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);

        const int dumpStatus = sendDumpRequest(proto, family, liveStates);
        if (dumpStatus != 0) {
            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-dumpStatus));
            return dumpStatus;
        }

        const int readStatus = readDiagMsg(proto, destroyFilter);
        if (readStatus != 0) {
            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-readStatus));
            return readStatus;
        }
    }

    return 0;
}

// Destroys sockets of the given protocol that are owned by uid.
//
// For IPv4 and then IPv6, dumps sockets in the ESTABLISHED, SYN_SENT and SYN_RECV states
// and destroys those whose idiag_uid equals uid. Returns 0 on success, otherwise the first
// error from sendDumpRequest or readDiagMsg (which is also logged).
int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
    mSocketsDestroyed = 0;
    Stopwatch s;

    // The null check covers the end-of-dump call, where readDiagMsg passes a null message.
    auto ownedByUid = [uid] (uint8_t, const inet_diag_msg *msg) {
        return (msg != nullptr && msg->idiag_uid == uid);
    };

    const int families[] = {AF_INET, AF_INET6};
    for (const int family : families) {
        const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
        uint32_t liveStates = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);

        const int dumpStatus = sendDumpRequest(proto, family, liveStates);
        if (dumpStatus != 0) {
            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-dumpStatus));
            return dumpStatus;
        }

        const int readStatus = readDiagMsg(proto, ownedByUid);
        if (readStatus != 0) {
            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-readStatus));
            return readStatus;
        }
    }

    // Only log when something was actually destroyed.
    if (mSocketsDestroyed > 0) {
        ALOGI("Destroyed %d sockets for UID in %.1f ms", mSocketsDestroyed, s.timeTaken());
    }

    return 0;
}

// Destroys live TCP sockets whose owning UID is inside uidRanges and is not listed in
// skipUids.
//
// Returns 0 on success, otherwise the error reported by destroyLiveSockets. When at least
// one socket was destroyed, a summary including the ranges, the skipped UIDs and the
// elapsed time is logged.
int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids) {
    mSocketsDestroyed = 0;
    Stopwatch s;

    // The null check covers the end-of-dump call, where readDiagMsg passes a null message.
    auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
        return msg != nullptr &&
               uidRanges.hasUid(msg->idiag_uid) &&
               skipUids.find(msg->idiag_uid) == skipUids.end();
    };

    if (int ret = destroyLiveSockets(shouldDestroy)) {
        return ret;
    }

    if (mSocketsDestroyed > 0) {
        // std::set already iterates in ascending order, so it can be joined directly: no
        // copy into a vector and no sort are needed, and nothing is built unless we log.
        ALOGI("Destroyed %d sockets for %s skip={%s} in %.1f ms",
              mSocketsDestroyed, uidRanges.toString().c_str(),
              android::base::Join(skipUids, " ").c_str(), s.timeTaken());
    }

    return 0;
}