/*
 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "webrtc/base/firewallsocketserver.h"

#include <assert.h>

#include <algorithm>

#include "webrtc/base/asyncsocket.h"
#include "webrtc/base/logging.h"

namespace rtc {

// Socket adapter that consults its owning FirewallSocketServer before letting
// traffic through. Outbound TCP connects, outbound/inbound UDP packets,
// listen attempts and inbound TCP connections are each checked; everything
// else is forwarded unchanged to the wrapped AsyncSocket.
class FirewallSocket : public AsyncSocketAdapter {
 public:
  // |server| supplies the rule set and is not owned. |socket| is handed to
  // AsyncSocketAdapter. |type| is the socket type (SOCK_STREAM / SOCK_DGRAM)
  // and selects which checks apply.
  FirewallSocket(FirewallSocketServer* server, AsyncSocket* socket, int type)
      : AsyncSocketAdapter(socket), server_(server), type_(type) {
  }

  // For stream sockets, refuses the connection (EHOSTUNREACH, SOCKET_ERROR)
  // when the rules deny TCP from the local address to |addr|. Other socket
  // types are forwarded without a check.
  virtual int Connect(const SocketAddress& addr) {
    if (type_ == SOCK_STREAM) {
      if (!server_->Check(FP_TCP, GetLocalAddress(), addr)) {
        LOG(LS_VERBOSE) << "FirewallSocket outbound TCP connection from "
                        << GetLocalAddress().ToSensitiveString() << " to "
                        << addr.ToSensitiveString() << " denied";
        // TODO: Handle this asynchronously.
        SetError(EHOSTUNREACH);
        return SOCKET_ERROR;
      }
    }
    return AsyncSocketAdapter::Connect(addr);
  }

  // Sends to the remote address via SendTo() so the outbound UDP check is
  // applied in one place.
  virtual int Send(const void* pv, size_t cb) {
    return SendTo(pv, cb, GetRemoteAddress());
  }

  // For datagram sockets, silently drops the packet when the rules deny UDP
  // from the local address to |addr|. A dropped packet still reports |cb|
  // bytes "sent", so the caller cannot distinguish it from a successful send.
  virtual int SendTo(const void* pv, size_t cb, const SocketAddress& addr) {
    if (type_ == SOCK_DGRAM) {
      if (!server_->Check(FP_UDP, GetLocalAddress(), addr)) {
        LOG(LS_VERBOSE) << "FirewallSocket outbound UDP packet from "
                        << GetLocalAddress().ToSensitiveString() << " to "
                        << addr.ToSensitiveString() << " dropped";
        return static_cast<int>(cb);
      }
    }
    return AsyncSocketAdapter::SendTo(pv, cb, addr);
  }

  // Receives via RecvFrom() so the inbound UDP check is applied; the sender
  // address is discarded.
  virtual int Recv(void* pv, size_t cb) {
    SocketAddress addr;
    return RecvFrom(pv, cb, &addr);
  }

  // For datagram sockets, keeps reading until a packet is received that the
  // rules allow (sender -> local address), or until the underlying read
  // returns <= 0, which is passed straight back to the caller. Denied packets
  // are logged and discarded. |paddr| may be NULL if the caller does not need
  // the sender address.
  virtual int RecvFrom(void* pv, size_t cb, SocketAddress* paddr) {
    if (type_ == SOCK_DGRAM) {
      // The rule check always needs the sender address, so fall back to a
      // scratch address when the caller did not supply one instead of
      // dereferencing a NULL pointer.
      SocketAddress scratch;
      SocketAddress* from = paddr ? paddr : &scratch;
      while (true) {
        int res = AsyncSocketAdapter::RecvFrom(pv, cb, from);
        if (res <= 0)
          return res;
        if (server_->Check(FP_UDP, *from, GetLocalAddress()))
          return res;
        LOG(LS_VERBOSE) << "FirewallSocket inbound UDP packet from "
                        << from->ToSensitiveString() << " to "
                        << GetLocalAddress().ToSensitiveString() << " dropped";
      }
    }
    return AsyncSocketAdapter::RecvFrom(pv, cb, paddr);
  }

  // Fails with -1 when the server has TCP listening disabled; no socket error
  // code is set in that case.
  virtual int Listen(int backlog) {
    if (!server_->tcp_listen_enabled()) {
      LOG(LS_VERBOSE) << "FirewallSocket listen attempt denied";
      return -1;
    }
    return AsyncSocketAdapter::Listen(backlog);
  }

  // Accepts pending connections until one is allowed by the rules (remote ->
  // local address). Denied connections are closed and deleted. Returns NULL
  // once the underlying Accept() has nothing more to hand out. |paddr| may be
  // NULL; it is only written for the connection that is returned.
  virtual AsyncSocket* Accept(SocketAddress* paddr) {
    SocketAddress addr;
    while (AsyncSocket* sock = AsyncSocketAdapter::Accept(&addr)) {
      if (server_->Check(FP_TCP, addr, GetLocalAddress())) {
        if (paddr)
          *paddr = addr;
        return sock;
      }
      sock->Close();
      delete sock;
      LOG(LS_VERBOSE) << "FirewallSocket inbound TCP connection from "
                      << addr.ToSensitiveString() << " to "
                      << GetLocalAddress().ToSensitiveString() << " denied";
    }
    return NULL;
  }

 private:
  FirewallSocketServer* server_;  // Rule source; not owned.
  int type_;                      // Socket type passed at construction.
};

// Wraps |server| and registers with |manager| (if any). All socket kinds and
// TCP listening start out enabled. |server| is deleted in the destructor only
// when |should_delete_server| is true.
FirewallSocketServer::FirewallSocketServer(SocketServer* server,
                                           FirewallManager* manager,
                                           bool should_delete_server)
    : server_(server), manager_(manager),
      should_delete_server_(should_delete_server),
      udp_sockets_enabled_(true), tcp_sockets_enabled_(true),
      tcp_listen_enabled_(true) {
  if (manager_)
    manager_->AddServer(this);
}

// Unregisters from the manager and, if owned, deletes the wrapped server.
FirewallSocketServer::~FirewallSocketServer() {
  if (manager_)
    manager_->RemoveServer(this);

  if (server_ && should_delete_server_) {
    delete server_;
    server_ = NULL;
  }
}

// Convenience overload: |addr| becomes the rule's destination for FD_IN and
// its source for any other direction; the opposite endpoint is left as a
// default-constructed SocketAddress.
void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
                                   FirewallDirection d,
                                   const SocketAddress& addr) {
  SocketAddress src, dst;
  if (d == FD_IN) {
    dst = addr;
  } else {
    src = addr;
  }
  AddRule(allow, p, src, dst);
}

// Appends a rule to the end of the rule list while holding crit_. Order
// matters: Check() returns the verdict of the first matching rule.
void FirewallSocketServer::AddRule(bool allow, FirewallProtocol p,
                                   const SocketAddress& src,
                                   const SocketAddress& dst) {
  Rule r;
  r.allow = allow;
  r.p = p;
  r.src = src;
  r.dst = dst;
  CritScope scope(&crit_);
  rules_.push_back(r);
}

// Removes every rule while holding crit_.
void FirewallSocketServer::ClearRules() {
  CritScope scope(&crit_);
  rules_.clear();
}

// Returns the |allow| flag of the first rule matching (p, src, dst), or true
// when no rule matches. Within a rule, FP_ANY matches every protocol, an
// address for which IsNil() is true matches every IP, and port 0 matches
// every port.
bool FirewallSocketServer::Check(FirewallProtocol p,
                                 const SocketAddress& src,
                                 const SocketAddress& dst) {
  CritScope scope(&crit_);
  for (size_t i = 0; i < rules_.size(); ++i) {
    const Rule& r = rules_[i];
    if ((r.p != p) && (r.p != FP_ANY))
      continue;
    if ((r.src.ipaddr() != src.ipaddr()) && !r.src.IsNil())
      continue;
    if ((r.src.port() != src.port()) && (r.src.port() != 0))
      continue;
    if ((r.dst.ipaddr() != dst.ipaddr()) && !r.dst.IsNil())
      continue;
    if ((r.dst.port() != dst.port()) && (r.dst.port() != 0))
      continue;
    return r.allow;
  }
  return true;
}

// Creates an AF_INET socket of the given type.
Socket* FirewallSocketServer::CreateSocket(int type) {
  return CreateSocket(AF_INET, type);
}

// Creates a firewalled socket. The wrapped server's CreateAsyncSocket() is
// used here as well, because WrapSocket() takes an AsyncSocket.
Socket* FirewallSocketServer::CreateSocket(int family, int type) {
  return WrapSocket(server_->CreateAsyncSocket(family, type), type);
}

// Creates an AF_INET async socket of the given type.
AsyncSocket* FirewallSocketServer::CreateAsyncSocket(int type) {
  return CreateAsyncSocket(AF_INET, type);
}

// Creates a firewalled async socket on top of the wrapped server.
AsyncSocket* FirewallSocketServer::CreateAsyncSocket(int family, int type) {
  return WrapSocket(server_->CreateAsyncSocket(family, type), type);
}

// Wraps |sock| in a FirewallSocket. Returns NULL, deleting |sock|, when
// |sock| is NULL or when sockets of |type| are currently disabled.
AsyncSocket* FirewallSocketServer::WrapSocket(AsyncSocket* sock, int type) {
  if (!sock ||
      (type == SOCK_STREAM && !tcp_sockets_enabled_) ||
      (type == SOCK_DGRAM && !udp_sockets_enabled_)) {
    LOG(LS_VERBOSE) << "FirewallSocketServer socket creation denied";
    delete sock;
    return NULL;
  }
  return new FirewallSocket(this, sock, type);
}

FirewallManager::FirewallManager() {
}

// Every registered server must have removed itself before the manager is
// destroyed.
FirewallManager::~FirewallManager() {
  assert(servers_.empty());
}

// Registers |server| so later AddRule()/ClearRules() calls reach it.
void FirewallManager::AddServer(FirewallSocketServer* server) {
  CritScope scope(&crit_);
  servers_.push_back(server);
}

// Removes every occurrence of |server| from the registered list.
void FirewallManager::RemoveServer(FirewallSocketServer* server) {
  CritScope scope(&crit_);
  servers_.erase(std::remove(servers_.begin(), servers_.end(), server),
                 servers_.end());
}

// Adds the same rule to every registered server.
void FirewallManager::AddRule(bool allow, FirewallProtocol p,
                              FirewallDirection d, const SocketAddress& addr) {
  CritScope scope(&crit_);
  for (std::vector<FirewallSocketServer*>::const_iterator it =
           servers_.begin(); it != servers_.end(); ++it) {
    (*it)->AddRule(allow, p, d, addr);
  }
}

// Clears the rules of every registered server.
void FirewallManager::ClearRules() {
  CritScope scope(&crit_);
  for (std::vector<FirewallSocketServer*>::const_iterator it =
           servers_.begin(); it != servers_.end(); ++it) {
    (*it)->ClearRules();
  }
}

}  // namespace rtc