// Copyright 2015 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include <arpa/inet.h> #include <map> #include <netdb.h> #include <string> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> #include <base/bind.h> #include <base/bind_helpers.h> #include <base/files/file_util.h> #include <base/message_loop/message_loop.h> #include <base/strings/stringprintf.h> #include <brillo/bind_lambda.h> #include <brillo/streams/file_stream.h> #include <brillo/streams/tls_stream.h> #include "buffet/socket_stream.h" #include "buffet/weave_error_conversion.h" namespace buffet { using weave::provider::Network; namespace { std::string GetIPAddress(const sockaddr* sa) { std::string addr; char str[INET6_ADDRSTRLEN] = {}; switch (sa->sa_family) { case AF_INET: if (inet_ntop(AF_INET, &(reinterpret_cast<const sockaddr_in*>(sa)->sin_addr), str, sizeof(str))) { addr = str; } break; case AF_INET6: if (inet_ntop(AF_INET6, &(reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr), str, sizeof(str))) { addr = str; } break; } if (addr.empty()) addr = base::StringPrintf("<Unknown address family: %d>", sa->sa_family); return addr; } int ConnectSocket(const std::string& host, uint16_t port) { std::string service = std::to_string(port); addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM}; addrinfo* result = nullptr; if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) { PLOG(WARNING) << "Failed to resolve host name: " << host; return -1; } int socket_fd = -1; for (const addrinfo* info = result; info != nullptr; info = info->ai_next) { socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol); if (socket_fd < 0) continue; std::string addr = GetIPAddress(info->ai_addr); LOG(INFO) << "Connecting to address: " << addr; if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0) break; // Success. PLOG(WARNING) << "Failed to connect to address: " << addr; close(socket_fd); socket_fd = -1; } freeaddrinfo(result); return socket_fd; } void OnSuccess(const Network::OpenSslSocketCallback& callback, brillo::StreamPtr tls_stream) { callback.Run( std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}}, nullptr); } void OnError(const weave::DoneCallback& callback, const brillo::Error* brillo_error) { weave::ErrorPtr error; ConvertError(*brillo_error, &error); callback.Run(std::move(error)); } } // namespace void SocketStream::Read(void* buffer, size_t size_to_read, const ReadCallback& callback) { brillo::ErrorPtr brillo_error; if (!ptr_->ReadAsync( buffer, size_to_read, base::Bind([](const ReadCallback& callback, size_t size) { callback.Run(size, nullptr); }, callback), base::Bind(&OnError, base::Bind(callback, 0)), &brillo_error)) { weave::ErrorPtr error; ConvertError(*brillo_error, &error); base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(callback, 0, base::Passed(&error))); } } void SocketStream::Write(const void* buffer, size_t size_to_write, const WriteCallback& callback) { brillo::ErrorPtr brillo_error; if (!ptr_->WriteAllAsync(buffer, size_to_write, base::Bind(callback, nullptr), base::Bind(&OnError, callback), &brillo_error)) { weave::ErrorPtr error; ConvertError(*brillo_error, &error); base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(callback, base::Passed(&error))); } } void SocketStream::CancelPendingOperations() { ptr_->CancelPendingAsyncOperations(); } std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking( const std::string& host, uint16_t port) { int socket_fd = ConnectSocket(host, port); if (socket_fd <= 0) return nullptr; auto ptr_ = brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr); if (ptr_) return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}}; close(socket_fd); return nullptr; } void SocketStream::TlsConnect(std::unique_ptr<Stream> socket, const std::string& host, const Network::OpenSslSocketCallback& callback) { SocketStream* stream = static_cast<SocketStream*>(socket.get()); brillo::TlsStream::Connect( std::move(stream->ptr_), host, base::Bind(&OnSuccess, callback), base::Bind(&OnError, base::Bind(callback, nullptr))); } } // namespace buffet