// Copyright 2015 The Weave Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "examples/provider/ssl_stream.h" #include <openssl/err.h> #include <base/bind.h> #include <base/bind_helpers.h> #include <weave/provider/task_runner.h> namespace weave { namespace examples { namespace { void AddSslError(ErrorPtr* error, const tracked_objects::Location& location, const std::string& error_code, unsigned long ssl_error_code) { ERR_load_BIO_strings(); SSL_load_error_strings(); Error::AddToPrintf(error, location, error_code, "%s: %s", ERR_lib_error_string(ssl_error_code), ERR_reason_error_string(ssl_error_code)); } void RetryAsyncTask(provider::TaskRunner* task_runner, const tracked_objects::Location& location, const base::Closure& task) { task_runner->PostDelayedTask(FROM_HERE, task, base::TimeDelta::FromMilliseconds(100)); } } // namespace void SSLStream::SslDeleter::operator()(BIO* bio) const { BIO_free(bio); } void SSLStream::SslDeleter::operator()(SSL* ssl) const { SSL_free(ssl); } void SSLStream::SslDeleter::operator()(SSL_CTX* ctx) const { SSL_CTX_free(ctx); } SSLStream::SSLStream(provider::TaskRunner* task_runner, std::unique_ptr<BIO, SslDeleter> stream_bio) : task_runner_{task_runner} { ctx_.reset(SSL_CTX_new(TLSv1_2_client_method())); CHECK(ctx_); ssl_.reset(SSL_new(ctx_.get())); SSL_set_bio(ssl_.get(), stream_bio.get(), stream_bio.get()); stream_bio.release(); // Owned by ssl now. SSL_set_connect_state(ssl_.get()); } SSLStream::~SSLStream() { CancelPendingOperations(); } void SSLStream::RunTask(const base::Closure& task) { task.Run(); } void SSLStream::Read(void* buffer, size_t size_to_read, const ReadCallback& callback) { int res = SSL_read(ssl_.get(), buffer, size_to_read); if (res > 0) { task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), base::Bind(callback, res, nullptr)), {}); return; } int err = SSL_get_error(ssl_.get(), res); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { return RetryAsyncTask( task_runner_, FROM_HERE, base::Bind(&SSLStream::Read, weak_ptr_factory_.GetWeakPtr(), buffer, size_to_read, callback)); } ErrorPtr weave_error; AddSslError(&weave_error, FROM_HERE, "read_failed", err); return task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), base::Bind(callback, 0, base::Passed(&weave_error))), {}); } void SSLStream::Write(const void* buffer, size_t size_to_write, const WriteCallback& callback) { int res = SSL_write(ssl_.get(), buffer, size_to_write); if (res > 0) { buffer = static_cast<const char*>(buffer) + res; size_to_write -= res; if (size_to_write == 0) { return task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), base::Bind(callback, nullptr)), {}); } return RetryAsyncTask( task_runner_, FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, size_to_write, callback)); } int err = SSL_get_error(ssl_.get(), res); if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) { return RetryAsyncTask( task_runner_, FROM_HERE, base::Bind(&SSLStream::Write, weak_ptr_factory_.GetWeakPtr(), buffer, size_to_write, callback)); } ErrorPtr weave_error; AddSslError(&weave_error, FROM_HERE, "write_failed", err); task_runner_->PostDelayedTask( FROM_HERE, base::Bind(&SSLStream::RunTask, weak_ptr_factory_.GetWeakPtr(), base::Bind(callback, base::Passed(&weave_error))), {}); } void SSLStream::CancelPendingOperations() { weak_ptr_factory_.InvalidateWeakPtrs(); } void SSLStream::Connect( provider::TaskRunner* task_runner, const std::string& host, uint16_t port, const provider::Network::OpenSslSocketCallback& callback) { SSL_library_init(); char end_point[255]; snprintf(end_point, sizeof(end_point), "%s:%u", host.c_str(), port); std::unique_ptr<BIO, SslDeleter> stream_bio(BIO_new_connect(end_point)); CHECK(stream_bio); BIO_set_nbio(stream_bio.get(), 1); std::unique_ptr<SSLStream> stream{ new SSLStream{task_runner, std::move(stream_bio)}}; ConnectBio(std::move(stream), callback); } void SSLStream::ConnectBio( std::unique_ptr<SSLStream> stream, const provider::Network::OpenSslSocketCallback& callback) { BIO* bio = SSL_get_rbio(stream->ssl_.get()); if (BIO_do_connect(bio) == 1) return DoHandshake(std::move(stream), callback); auto task_runner = stream->task_runner_; if (BIO_should_retry(bio)) { return RetryAsyncTask( task_runner, FROM_HERE, base::Bind(&SSLStream::ConnectBio, base::Passed(&stream), callback)); } ErrorPtr error; AddSslError(&error, FROM_HERE, "connect_failed", ERR_get_error()); task_runner->PostDelayedTask( FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } void SSLStream::DoHandshake( std::unique_ptr<SSLStream> stream, const provider::Network::OpenSslSocketCallback& callback) { int res = SSL_do_handshake(stream->ssl_.get()); auto task_runner = stream->task_runner_; if (res == 1) { return task_runner->PostDelayedTask( FROM_HERE, base::Bind(callback, base::Passed(&stream), nullptr), {}); } res = SSL_get_error(stream->ssl_.get(), res); if (res == SSL_ERROR_WANT_READ || res == SSL_ERROR_WANT_WRITE) { return RetryAsyncTask( task_runner, FROM_HERE, base::Bind(&SSLStream::DoHandshake, base::Passed(&stream), callback)); } ErrorPtr error; AddSslError(&error, FROM_HERE, "handshake_failed", res); task_runner->PostDelayedTask( FROM_HERE, base::Bind(callback, nullptr, base::Passed(&error)), {}); } } // namespace examples } // namespace weave