普通文本  |  316行  |  8.06 KB

// Copyright 2012 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "polo/pairing/pairingsession.h"

#include <glog/logging.h>
#include "polo/encoding/hexadecimalencoder.h"
#include "polo/util/poloutil.h"

namespace polo {
namespace pairing {

PairingSession::PairingSession(wire::PoloWireAdapter* wire,
                               PairingContext* context,
                               PoloChallengeResponse* challenge)
    : state_(kUninitialized),
      wire_(wire),
      context_(context),
      challenge_(challenge),
      configuration_(NULL),
      encoder_(NULL),
      nonce_(NULL),
      secret_(NULL) {
  wire_->set_listener(this);

  local_options_.set_protocol_role_preference(context->is_server() ?
      message::OptionsMessage::kDisplayDevice
      : message::OptionsMessage::kInputDevice);
}

PairingSession::~PairingSession() {
  if (configuration_) {
    delete configuration_;
  }

  if (encoder_) {
    delete encoder_;
  }

  if (nonce_) {
    delete nonce_;
  }

  if (secret_) {
    delete secret_;
  }
}

void PairingSession::AddInputEncoding(
    const encoding::EncodingOption& encoding) {
  if (state_ != kUninitialized) {
    LOG(ERROR) << "Attempt to add input encoding to active session";
    return;
  }

  if (!IsValidEncodingOption(encoding)) {
    LOG(ERROR) << "Invalid input encoding: " << encoding.ToString();
    return;
  }

  local_options_.AddInputEncoding(encoding);
}

void PairingSession::AddOutputEncoding(
    const encoding::EncodingOption& encoding) {
  if (state_ != kUninitialized) {
    LOG(ERROR) << "Attempt to add output encoding to active session";
    return;
  }

  if (!IsValidEncodingOption(encoding)) {
    LOG(ERROR) << "Invalid output encoding: " << encoding.ToString();
    return;
  }

  local_options_.AddOutputEncoding(encoding);
}

bool PairingSession::SetSecret(const Gamma& secret) {
  secret_ = new Gamma(secret);

  if (!IsInputDevice() || state_ != kWaitingForSecret) {
    LOG(ERROR) << "Invalid state: unexpected secret";
    return false;
  }

  if (!challenge().CheckGamma(secret)) {
    LOG(ERROR) << "Secret failed local check";
    return false;
  }

  nonce_ = challenge().ExtractNonce(secret);
  if (!nonce_) {
    LOG(ERROR) << "Failed to extract nonce";
    return false;
  }

  const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
  if (!gen_alpha) {
    LOG(ERROR) << "Failed to get alpha";
    return false;
  }

  message::SecretMessage secret_message(*gen_alpha);
  delete gen_alpha;

  wire_->SendSecretMessage(secret_message);

  LOG(INFO) << "Waiting for SecretAck...";
  wire_->GetNextMessage();

  return true;
}

void PairingSession::DoPair(PairingListener *listener) {
  listener_ = listener;
  listener_->OnSessionCreated();

  if (context_->is_server()) {
    LOG(INFO) << "Pairing started (SERVER mode)";
  } else {
    LOG(INFO) << "Pairing started (CLIENT mode)";
  }
  LOG(INFO) << "Local options: " << local_options_.ToString();

  set_state(kInitializing);
  DoInitializationPhase();
}

void PairingSession::DoPairingPhase() {
  if (IsInputDevice()) {
    DoInputPairing();
  } else {
    DoOutputPairing();
  }
}

void PairingSession::DoInputPairing() {
  set_state(kWaitingForSecret);
  listener_->OnPerformInputDeviceRole();
}

void PairingSession::DoOutputPairing() {
  size_t nonce_length = configuration_->encoding().symbol_length() / 2;
  size_t bytes_needed = nonce_length / encoder_->symbols_per_byte();

  uint8_t* random = util::PoloUtil::GenerateRandomBytes(bytes_needed);
  nonce_ = new Nonce(random, random + bytes_needed);
  delete[] random;

  const Gamma* gamma = challenge().GetGamma(*nonce_);
  if (!gamma) {
    LOG(ERROR) << "Failed to get gamma";
    wire()->SendErrorMessage(kErrorProtocol);
    listener()->OnError(kErrorProtocol);
    return;
  }

  listener_->OnPerformOutputDeviceRole(*gamma);
  delete gamma;

  set_state(kWaitingForSecret);

  LOG(INFO) << "Waiting for Secret...";
  wire_->GetNextMessage();
}

void PairingSession::set_state(ProtocolState state) {
  LOG(INFO) << "New state: " << state;
  state_ = state;
}

bool PairingSession::SetConfiguration(
    const message::ConfigurationMessage& message) {
  const encoding::EncodingOption& encoding = message.encoding();

  if (!IsValidEncodingOption(encoding)) {
    LOG(ERROR) << "Invalid configuration: " << encoding.ToString();
    return false;
  }

  if (encoder_) {
    delete encoder_;
    encoder_ = NULL;
  }

  switch (encoding.encoding_type()) {
    case encoding::EncodingOption::kHexadecimal:
      encoder_ = new encoding::HexadecimalEncoder();
      break;
    default:
      LOG(ERROR) << "Unsupported encoding type: "
          << encoding.encoding_type();
      return false;
  }

  if (configuration_) {
    delete configuration_;
  }
  configuration_ = new message::ConfigurationMessage(message.encoding(),
                                                     message.client_role());
  return true;
}

void PairingSession::OnSecretMessage(const message::SecretMessage& message) {
  if (state() != kWaitingForSecret) {
    LOG(ERROR) << "Invalid state: unexpected secret message";
    wire()->SendErrorMessage(kErrorProtocol);
    listener()->OnError(kErrorProtocol);
    return;
  }

  if (!VerifySecret(message.secret())) {
    wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
    listener_->OnError(kErrorInvalidChallengeResponse);
    return;
  }

  const Alpha* alpha = challenge().GetAlpha(*nonce_);
  if (!alpha) {
    LOG(ERROR) << "Failed to get alpha";
    wire()->SendErrorMessage(kErrorProtocol);
    listener()->OnError(kErrorProtocol);
    return;
  }

  message::SecretAckMessage ack(*alpha);
  delete alpha;

  wire_->SendSecretAckMessage(ack);

  listener_->OnPairingSuccess();
}

void PairingSession::OnSecretAckMessage(
    const message::SecretAckMessage& message) {
  if (kVerifySecretAck && !VerifySecret(message.secret())) {
    wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
    listener_->OnError(kErrorInvalidChallengeResponse);
    return;
  }

  listener_->OnPairingSuccess();
}

void PairingSession::OnError(pairing::PoloError error) {
  listener_->OnError(error);
}

bool PairingSession::VerifySecret(const Alpha& secret) const {
  if (!nonce_) {
    LOG(ERROR) << "Nonce not set";
    return false;
  }

  const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
  if (!gen_alpha) {
    LOG(ERROR) << "Failed to get alpha";
    return false;
  }

  bool valid = (secret == *gen_alpha);

  if (!valid) {
    LOG(ERROR) << "Inband secret did not match. Expected ["
        << util::PoloUtil::BytesToHexString(&(*gen_alpha)[0], gen_alpha->size())
        << "], got ["
        << util::PoloUtil::BytesToHexString(&secret[0], secret.size())
        << "]";
  }

  delete gen_alpha;
  return valid;
}

message::OptionsMessage::ProtocolRole PairingSession::GetLocalRole() const {
  if (!configuration_) {
    return message::OptionsMessage::kUnknown;
  }

  if (context_->is_client()) {
    return configuration_->client_role();
  } else {
    return configuration_->client_role() ==
        message::OptionsMessage::kDisplayDevice ?
            message::OptionsMessage::kInputDevice
            : message::OptionsMessage::kDisplayDevice;
  }
}

bool PairingSession::IsInputDevice() const {
  return GetLocalRole() == message::OptionsMessage::kInputDevice;
}

bool PairingSession::IsValidEncodingOption(
    const encoding::EncodingOption& option) const {
  // Legal values of GAMMALEN must be an even number of at least 2 bytes.
  return option.encoding_type() != encoding::EncodingOption::kUnknown
      && (option.symbol_length() % 2 == 0)
      && (option.symbol_length() >= 2);
}

}  // namespace pairing
}  // namespace polo