/*
 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "webrtc/base/natsocketfactory.h"

#include "webrtc/base/logging.h"
#include "webrtc/base/natserver.h"
#include "webrtc/base/virtualsocketserver.h"

namespace rtc {

// Packs the given socketaddress into the buffer in buf, in the quasi-STUN
// format that the natserver uses:
//   buf[0]     zero
//   buf[1]     address family
//   buf[2..3]  port, network byte order
//   buf[4..]   raw IPv4 or IPv6 address bytes
// Returns the number of bytes written (kNATEncodedIPv4AddressSize or
// kNATEncodedIPv6AddressSize). Returns 0, leaving buf untouched, if an
// invalid address is passed or if buf_size is too small for the encoding.
size_t PackAddressForNAT(char* buf, size_t buf_size,
                         const SocketAddress& remote_addr) {
  const IPAddress& ip = remote_addr.ipaddr();
  const int family = ip.family();

  // Validate everything before touching the output buffer.
  size_t encoded_size = 0U;
  if (family == AF_INET) {
    encoded_size = kNATEncodedIPv4AddressSize;
  } else if (family == AF_INET6) {
    encoded_size = kNATEncodedIPv6AddressSize;
  } else {
    return 0U;
  }
  ASSERT(buf_size >= encoded_size);
  if (buf_size < encoded_size) {
    // ASSERT compiles away in release builds; never overrun the buffer.
    return 0U;
  }

  buf[0] = 0;
  buf[1] = static_cast<char>(family);
  // Writes the port. memcpy is used because &buf[2] carries no alignment
  // guarantee for a 16-bit store and a char buffer must not be written
  // through a uint16 lvalue.
  const uint16 net_port = HostToNetwork16(remote_addr.port());
  memcpy(&buf[2], &net_port, sizeof(net_port));

  if (family == AF_INET) {
    const in_addr v4addr = ip.ipv4_address();
    memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
  } else {
    const in6_addr v6addr = ip.ipv6_address();
    memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
  }
  return encoded_size;
}

// Decodes the remote address from a packet that has been encoded with the nat's
// quasi-STUN format. Returns the length of the address (i.e., the offset into
// data where the original packet starts).
size_t UnpackAddressFromNAT(const char* buf, size_t buf_size, SocketAddress* remote_addr) { ASSERT(buf_size >= 8); ASSERT(buf[0] == 0); int family = buf[1]; uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2]))); if (family == AF_INET) { const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]); *remote_addr = SocketAddress(IPAddress(*v4addr), port); return kNATEncodedIPv4AddressSize; } else if (family == AF_INET6) { ASSERT(buf_size >= 20); const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]); *remote_addr = SocketAddress(IPAddress(*v6addr), port); return kNATEncodedIPv6AddressSize; } return 0U; } // NATSocket class NATSocket : public AsyncSocket, public sigslot::has_slots<> { public: explicit NATSocket(NATInternalSocketFactory* sf, int family, int type) : sf_(sf), family_(family), type_(type), connected_(false), socket_(NULL), buf_(NULL), size_(0) { } virtual ~NATSocket() { delete socket_; delete[] buf_; } virtual SocketAddress GetLocalAddress() const { return (socket_) ? socket_->GetLocalAddress() : SocketAddress(); } virtual SocketAddress GetRemoteAddress() const { return remote_addr_; // will be NIL if not connected } virtual int Bind(const SocketAddress& addr) { if (socket_) { // already bound, bubble up error return -1; } int result; socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_); result = (socket_) ? socket_->Bind(addr) : -1; if (result >= 0) { socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent); socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent); } else { server_addr_.Clear(); delete socket_; socket_ = NULL; } return result; } virtual int Connect(const SocketAddress& addr) { if (!socket_) { // socket must be bound, for now return -1; } int result = 0; if (type_ == SOCK_STREAM) { result = socket_->Connect(server_addr_.IsNil() ? 
addr : server_addr_); } else { connected_ = true; } if (result >= 0) { remote_addr_ = addr; } return result; } virtual int Send(const void* data, size_t size) { ASSERT(connected_); return SendTo(data, size, remote_addr_); } virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) { ASSERT(!connected_ || addr == remote_addr_); if (server_addr_.IsNil() || type_ == SOCK_STREAM) { return socket_->SendTo(data, size, addr); } // This array will be too large for IPv4 packets, but only by 12 bytes. scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]); size_t addrlength = PackAddressForNAT(buf.get(), size + kNATEncodedIPv6AddressSize, addr); size_t encoded_size = size + addrlength; memcpy(buf.get() + addrlength, data, size); int result = socket_->SendTo(buf.get(), encoded_size, server_addr_); if (result >= 0) { ASSERT(result == static_cast<int>(encoded_size)); result = result - static_cast<int>(addrlength); } return result; } virtual int Recv(void* data, size_t size) { SocketAddress addr; return RecvFrom(data, size, &addr); } virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) { if (server_addr_.IsNil() || type_ == SOCK_STREAM) { return socket_->RecvFrom(data, size, out_addr); } // Make sure we have enough room to read the requested amount plus the // largest possible header address. SocketAddress remote_addr; Grow(size + kNATEncodedIPv6AddressSize); // Read the packet from the socket. int result = socket_->RecvFrom(buf_, size_, &remote_addr); if (result >= 0) { ASSERT(remote_addr == server_addr_); // TODO: we need better framing so we know how many bytes we can // return before we need to read the next address. For UDP, this will be // fine as long as the reader always reads everything in the packet. ASSERT((size_t)result < size_); // Decode the wire packet into the actual results. 
SocketAddress real_remote_addr; size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr); memcpy(data, buf_ + addrlength, result - addrlength); // Make sure this packet should be delivered before returning it. if (!connected_ || (real_remote_addr == remote_addr_)) { if (out_addr) *out_addr = real_remote_addr; result = result - static_cast<int>(addrlength); } else { LOG(LS_ERROR) << "Dropping packet from unknown remote address: " << real_remote_addr.ToString(); result = 0; // Tell the caller we didn't read anything } } return result; } virtual int Close() { int result = 0; if (socket_) { result = socket_->Close(); if (result >= 0) { connected_ = false; remote_addr_ = SocketAddress(); delete socket_; socket_ = NULL; } } return result; } virtual int Listen(int backlog) { return socket_->Listen(backlog); } virtual AsyncSocket* Accept(SocketAddress *paddr) { return socket_->Accept(paddr); } virtual int GetError() const { return socket_->GetError(); } virtual void SetError(int error) { socket_->SetError(error); } virtual ConnState GetState() const { return connected_ ? CS_CONNECTED : CS_CLOSED; } virtual int EstimateMTU(uint16* mtu) { return socket_->EstimateMTU(mtu); } virtual int GetOption(Option opt, int* value) { return socket_->GetOption(opt, value); } virtual int SetOption(Option opt, int value) { return socket_->SetOption(opt, value); } void OnConnectEvent(AsyncSocket* socket) { // If we're NATed, we need to send a request with the real addr to use. ASSERT(socket == socket_); if (server_addr_.IsNil()) { connected_ = true; SignalConnectEvent(this); } else { SendConnectRequest(); } } void OnReadEvent(AsyncSocket* socket) { // If we're NATed, we need to process the connect reply. 
ASSERT(socket == socket_); if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) { HandleConnectReply(); } else { SignalReadEvent(this); } } void OnWriteEvent(AsyncSocket* socket) { ASSERT(socket == socket_); SignalWriteEvent(this); } void OnCloseEvent(AsyncSocket* socket, int error) { ASSERT(socket == socket_); SignalCloseEvent(this, error); } private: // Makes sure the buffer is at least the given size. void Grow(size_t new_size) { if (size_ < new_size) { delete[] buf_; size_ = new_size; buf_ = new char[size_]; } } // Sends the destination address to the server to tell it to connect. void SendConnectRequest() { char buf[256]; size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_); socket_->Send(buf, length); } // Handles the byte sent back from the server and fires the appropriate event. void HandleConnectReply() { char code; socket_->Recv(&code, sizeof(code)); if (code == 0) { SignalConnectEvent(this); } else { Close(); SignalCloseEvent(this, code); } } NATInternalSocketFactory* sf_; int family_; int type_; bool connected_; SocketAddress remote_addr_; SocketAddress server_addr_; // address of the NAT server AsyncSocket* socket_; char* buf_; size_t size_; }; // NATSocketFactory NATSocketFactory::NATSocketFactory(SocketFactory* factory, const SocketAddress& nat_addr) : factory_(factory), nat_addr_(nat_addr) { } Socket* NATSocketFactory::CreateSocket(int type) { return CreateSocket(AF_INET, type); } Socket* NATSocketFactory::CreateSocket(int family, int type) { return new NATSocket(this, family, type); } AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) { return CreateAsyncSocket(AF_INET, type); } AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) { return new NATSocket(this, family, type); } AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type, const SocketAddress& local_addr, SocketAddress* nat_addr) { *nat_addr = nat_addr_; return factory_->CreateAsyncSocket(family, type); } // 
NATSocketServer NATSocketServer::NATSocketServer(SocketServer* server) : server_(server), msg_queue_(NULL) { } NATSocketServer::Translator* NATSocketServer::GetTranslator( const SocketAddress& ext_ip) { return nats_.Get(ext_ip); } NATSocketServer::Translator* NATSocketServer::AddTranslator( const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { // Fail if a translator already exists with this extternal address. if (nats_.Get(ext_ip)) return NULL; return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip)); } void NATSocketServer::RemoveTranslator( const SocketAddress& ext_ip) { nats_.Remove(ext_ip); } Socket* NATSocketServer::CreateSocket(int type) { return CreateSocket(AF_INET, type); } Socket* NATSocketServer::CreateSocket(int family, int type) { return new NATSocket(this, family, type); } AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) { return CreateAsyncSocket(AF_INET, type); } AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) { return new NATSocket(this, family, type); } AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type, const SocketAddress& local_addr, SocketAddress* nat_addr) { AsyncSocket* socket = NULL; Translator* nat = nats_.FindClient(local_addr); if (nat) { socket = nat->internal_factory()->CreateAsyncSocket(family, type); *nat_addr = (type == SOCK_STREAM) ? nat->internal_tcp_address() : nat->internal_address(); } else { socket = server_->CreateAsyncSocket(family, type); } return socket; } // NATSocketServer::Translator NATSocketServer::Translator::Translator( NATSocketServer* server, NATType type, const SocketAddress& int_ip, SocketFactory* ext_factory, const SocketAddress& ext_ip) : server_(server) { // Create a new private network, and a NATServer running on the private // network that bridges to the external network. Also tell the private // network to use the same message queue as us. 
VirtualSocketServer* internal_server = new VirtualSocketServer(server_); internal_server->SetMessageQueue(server_->queue()); internal_factory_.reset(internal_server); nat_server_.reset(new NATServer(type, internal_server, int_ip, ext_factory, ext_ip)); } NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator( const SocketAddress& ext_ip) { return nats_.Get(ext_ip); } NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator( const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { // Fail if a translator already exists with this extternal address. if (nats_.Get(ext_ip)) return NULL; AddClient(ext_ip); return nats_.Add(ext_ip, new Translator(server_, type, int_ip, server_, ext_ip)); } void NATSocketServer::Translator::RemoveTranslator( const SocketAddress& ext_ip) { nats_.Remove(ext_ip); RemoveClient(ext_ip); } bool NATSocketServer::Translator::AddClient( const SocketAddress& int_ip) { // Fail if a client already exists with this internal address. if (clients_.find(int_ip) != clients_.end()) return false; clients_.insert(int_ip); return true; } void NATSocketServer::Translator::RemoveClient( const SocketAddress& int_ip) { std::set<SocketAddress>::iterator it = clients_.find(int_ip); if (it != clients_.end()) { clients_.erase(it); } } NATSocketServer::Translator* NATSocketServer::Translator::FindClient( const SocketAddress& int_ip) { // See if we have the requested IP, or any of our children do. return (clients_.find(int_ip) != clients_.end()) ? this : nats_.FindClient(int_ip); } // NATSocketServer::TranslatorMap NATSocketServer::TranslatorMap::~TranslatorMap() { for (TranslatorMap::iterator it = begin(); it != end(); ++it) { delete it->second; } } NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get( const SocketAddress& ext_ip) { TranslatorMap::iterator it = find(ext_ip); return (it != end()) ? 
it->second : NULL; } NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add( const SocketAddress& ext_ip, Translator* nat) { (*this)[ext_ip] = nat; return nat; } void NATSocketServer::TranslatorMap::Remove( const SocketAddress& ext_ip) { TranslatorMap::iterator it = find(ext_ip); if (it != end()) { delete it->second; erase(it); } } NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient( const SocketAddress& int_ip) { Translator* nat = NULL; for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) { nat = it->second->FindClient(int_ip); } return nat; } } // namespace rtc