// Copyright 2012 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "polo/pairing/pairingsession.h" #include <glog/logging.h> #include "polo/encoding/hexadecimalencoder.h" #include "polo/util/poloutil.h" namespace polo { namespace pairing { PairingSession::PairingSession(wire::PoloWireAdapter* wire, PairingContext* context, PoloChallengeResponse* challenge) : state_(kUninitialized), wire_(wire), context_(context), challenge_(challenge), configuration_(NULL), encoder_(NULL), nonce_(NULL), secret_(NULL) { wire_->set_listener(this); local_options_.set_protocol_role_preference(context->is_server() ? message::OptionsMessage::kDisplayDevice : message::OptionsMessage::kInputDevice); } PairingSession::~PairingSession() { if (configuration_) { delete configuration_; } if (encoder_) { delete encoder_; } if (nonce_) { delete nonce_; } if (secret_) { delete secret_; } } void PairingSession::AddInputEncoding( const encoding::EncodingOption& encoding) { if (state_ != kUninitialized) { LOG(ERROR) << "Attempt to add input encoding to active session"; return; } if (!IsValidEncodingOption(encoding)) { LOG(ERROR) << "Invalid input encoding: " << encoding.ToString(); return; } local_options_.AddInputEncoding(encoding); } void PairingSession::AddOutputEncoding( const encoding::EncodingOption& encoding) { if (state_ != kUninitialized) { LOG(ERROR) << "Attempt to add output encoding to active session"; return; } if (!IsValidEncodingOption(encoding)) { LOG(ERROR) << "Invalid output encoding: " << encoding.ToString(); return; } local_options_.AddOutputEncoding(encoding); } bool PairingSession::SetSecret(const Gamma& secret) { secret_ = new Gamma(secret); if (!IsInputDevice() || state_ != kWaitingForSecret) { LOG(ERROR) << "Invalid state: unexpected secret"; return false; } if (!challenge().CheckGamma(secret)) { LOG(ERROR) << "Secret failed local check"; return false; } nonce_ = challenge().ExtractNonce(secret); if (!nonce_) { LOG(ERROR) << "Failed to extract nonce"; return false; } const Alpha* gen_alpha = challenge().GetAlpha(*nonce_); if (!gen_alpha) { LOG(ERROR) << "Failed to get alpha"; return false; } message::SecretMessage secret_message(*gen_alpha); delete gen_alpha; wire_->SendSecretMessage(secret_message); LOG(INFO) << "Waiting for SecretAck..."; wire_->GetNextMessage(); return true; } void PairingSession::DoPair(PairingListener *listener) { listener_ = listener; listener_->OnSessionCreated(); if (context_->is_server()) { LOG(INFO) << "Pairing started (SERVER mode)"; } else { LOG(INFO) << "Pairing started (CLIENT mode)"; } LOG(INFO) << "Local options: " << local_options_.ToString(); set_state(kInitializing); DoInitializationPhase(); } void PairingSession::DoPairingPhase() { if (IsInputDevice()) { DoInputPairing(); } else { DoOutputPairing(); } } void PairingSession::DoInputPairing() { set_state(kWaitingForSecret); listener_->OnPerformInputDeviceRole(); } void PairingSession::DoOutputPairing() { size_t nonce_length = configuration_->encoding().symbol_length() / 2; size_t bytes_needed = nonce_length / encoder_->symbols_per_byte(); uint8_t* random = util::PoloUtil::GenerateRandomBytes(bytes_needed); nonce_ = new Nonce(random, random + bytes_needed); delete[] random; const Gamma* gamma = challenge().GetGamma(*nonce_); if (!gamma) { LOG(ERROR) << "Failed to get gamma"; wire()->SendErrorMessage(kErrorProtocol); listener()->OnError(kErrorProtocol); return; } listener_->OnPerformOutputDeviceRole(*gamma); delete gamma; set_state(kWaitingForSecret); LOG(INFO) << "Waiting for Secret..."; wire_->GetNextMessage(); } void PairingSession::set_state(ProtocolState state) { LOG(INFO) << "New state: " << state; state_ = state; } bool PairingSession::SetConfiguration( const message::ConfigurationMessage& message) { const encoding::EncodingOption& encoding = message.encoding(); if (!IsValidEncodingOption(encoding)) { LOG(ERROR) << "Invalid configuration: " << encoding.ToString(); return false; } if (encoder_) { delete encoder_; encoder_ = NULL; } switch (encoding.encoding_type()) { case encoding::EncodingOption::kHexadecimal: encoder_ = new encoding::HexadecimalEncoder(); break; default: LOG(ERROR) << "Unsupported encoding type: " << encoding.encoding_type(); return false; } if (configuration_) { delete configuration_; } configuration_ = new message::ConfigurationMessage(message.encoding(), message.client_role()); return true; } void PairingSession::OnSecretMessage(const message::SecretMessage& message) { if (state() != kWaitingForSecret) { LOG(ERROR) << "Invalid state: unexpected secret message"; wire()->SendErrorMessage(kErrorProtocol); listener()->OnError(kErrorProtocol); return; } if (!VerifySecret(message.secret())) { wire()->SendErrorMessage(kErrorInvalidChallengeResponse); listener_->OnError(kErrorInvalidChallengeResponse); return; } const Alpha* alpha = challenge().GetAlpha(*nonce_); if (!alpha) { LOG(ERROR) << "Failed to get alpha"; wire()->SendErrorMessage(kErrorProtocol); listener()->OnError(kErrorProtocol); return; } message::SecretAckMessage ack(*alpha); delete alpha; wire_->SendSecretAckMessage(ack); listener_->OnPairingSuccess(); } void PairingSession::OnSecretAckMessage( const message::SecretAckMessage& message) { if (kVerifySecretAck && !VerifySecret(message.secret())) { wire()->SendErrorMessage(kErrorInvalidChallengeResponse); listener_->OnError(kErrorInvalidChallengeResponse); return; } listener_->OnPairingSuccess(); } void PairingSession::OnError(pairing::PoloError error) { listener_->OnError(error); } bool PairingSession::VerifySecret(const Alpha& secret) const { if (!nonce_) { LOG(ERROR) << "Nonce not set"; return false; } const Alpha* gen_alpha = challenge().GetAlpha(*nonce_); if (!gen_alpha) { LOG(ERROR) << "Failed to get alpha"; return false; } bool valid = (secret == *gen_alpha); if (!valid) { LOG(ERROR) << "Inband secret did not match. Expected [" << util::PoloUtil::BytesToHexString(&(*gen_alpha)[0], gen_alpha->size()) << "], got [" << util::PoloUtil::BytesToHexString(&secret[0], secret.size()) << "]"; } delete gen_alpha; return valid; } message::OptionsMessage::ProtocolRole PairingSession::GetLocalRole() const { if (!configuration_) { return message::OptionsMessage::kUnknown; } if (context_->is_client()) { return configuration_->client_role(); } else { return configuration_->client_role() == message::OptionsMessage::kDisplayDevice ? message::OptionsMessage::kInputDevice : message::OptionsMessage::kDisplayDevice; } } bool PairingSession::IsInputDevice() const { return GetLocalRole() == message::OptionsMessage::kInputDevice; } bool PairingSession::IsValidEncodingOption( const encoding::EncodingOption& option) const { // Legal values of GAMMALEN must be an even number of at least 2 bytes. return option.encoding_type() != encoding::EncodingOption::kUnknown && (option.symbol_length() % 2 == 0) && (option.symbol_length() >= 2); } } // namespace pairing } // namespace polo