#!/usr/bin/env python
# Copyright 2011 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""System integration test for traffic shaping.

Usage:
$ sudo ./trafficshaper_test.py
"""

import daemonserver
import logging
import platformsettings
import socket
import SocketServer
import trafficshaper
import unittest

# Ports the test servers listen on: TEST_DNS_PORT for the UDP server and
# TEST_HTTP_PORT for the TCP server.
TEST_DNS_PORT = 5555
TEST_HTTP_PORT = 8888
# Request prefix that asks TimedTcpHandler to pad its reply to a byte count.
RESPONSE_SIZE_KEY = 'response-size:'
# Clock used both by the test servers and by the client-side measurements.
TIMER = platformsettings.timer


def GetElapsedMs(start_time, end_time):
  """Convert the gap between two timestamps into whole milliseconds.

  Args:
    start_time: timestamp in seconds; a float or a string holding one.
    end_time: timestamp in seconds; a float or a string holding one.
  Returns:
    int, (end_time - start_time) expressed in milliseconds, truncated
    toward zero.
  """
  elapsed_seconds = float(end_time) - float(start_time)
  return int(elapsed_seconds * 1000)


class TrafficShaperTest(unittest.TestCase):
  """Argument-validation checks for trafficshaper.TrafficShaper."""

  def testBadBandwidthRaises(self):
    """A bandwidth with the unit spelled 'KBit/s' must be rejected."""
    malformed_bandwidth = '1KBit/s'
    self.assertRaises(trafficshaper.BandwidthValueError,
                      trafficshaper.TrafficShaper,
                      down_bandwidth=malformed_bandwidth)


class TimedUdpHandler(SocketServer.DatagramRequestHandler):
  """Answers each datagram with the server timer's value when handled."""

  def handle(self):
    # Drain the request payload; its contents are not needed.
    self.rfile.read()
    self.wfile.write(str(self.server.timer()))


class TimedTcpHandler(SocketServer.StreamRequestHandler):
  """TCP handler that replies with the time at which the request was read.

  By default the reply is just str(timer()). If the request starts with
  RESPONSE_SIZE_KEY, the reply is padded with NUL bytes to the requested
  total length. Request layout:
    RESPONSE_SIZE_KEY num_response_bytes <newline> ANY_DATA
  Reply layout:
    str(read_time) <newline> NUL padding
  """

  def handle(self):
    # rfile.read() without a size only returns at EOF, so clients must
    # half-close (shutdown SHUT_WR) their side before awaiting the reply.
    request = self.rfile.read()
    timestamp = str(self.server.timer())
    if not request.startswith(RESPONSE_SIZE_KEY):
      self.wfile.write(timestamp)
      return
    size_end = request.index('\n')
    total_size = int(request[len(RESPONSE_SIZE_KEY):size_end])
    padding = '\x00' * (total_size - len(timestamp) - 1)
    self.wfile.write(timestamp + '\n' + padding)


class TimedUdpServer(SocketServer.ThreadingUDPServer,
                     daemonserver.DaemonServer):
  """A simple UDP server similar to dnsproxy.

  Each datagram is answered by TimedUdpHandler with the value of |timer|
  at the moment the request was handled.
  """

  # Set SO_REUSEADDR (SocketServer.UDPServer defaults to False) so servers
  # re-created on the same port in quick succession can bind.
  allow_reuse_address = True

  def __init__(self, host, port, timer=TIMER):
    SocketServer.ThreadingUDPServer.__init__(
        self, (host, port), TimedUdpHandler)
    self.timer = timer

  def cleanup(self):
    """Stop the serve loop and close the server socket.

    Previously a no-op, which left the serving thread and its socket alive
    after the with-block even though tests re-create the server on the
    same port. shutdown() waits for serve_forever() to return; this
    assumes DaemonServer is running serve_forever() (as TimedTcpServer's
    cleanup already assumes) -- TODO confirm against daemonserver.
    """
    try:
      self.shutdown()
    except KeyboardInterrupt:
      pass
    self.server_close()


class TimedTcpServer(SocketServer.ThreadingTCPServer,
                     daemonserver.DaemonServer):
  """A simple TCP server similar to httpproxy.

  Each connection is answered by TimedTcpHandler with the value of |timer|
  at the moment the request was read.
  """

  # Set SO_REUSEADDR (SocketServer.TCPServer defaults to False) so the port
  # can be re-bound right after a previous server on it was closed.
  allow_reuse_address = True

  def __init__(self, host, port, timer=TIMER):
    SocketServer.ThreadingTCPServer.__init__(
        self, (host, port), TimedTcpHandler)
    self.timer = timer

  def cleanup(self):
    """Stop the serve loop and close the listening socket.

    shutdown() waits for serve_forever() to return. A KeyboardInterrupt
    raised during that wait is swallowed so the socket is still closed.
    """
    try:
      self.shutdown()
    except KeyboardInterrupt:
      pass
    # Release the listening socket instead of leaking it until GC.
    self.server_close()


class TcpTestSocketCreator(object):
  """Context manager that opens a TCP connection and closes it on exit.

  Usage:
    with TcpTestSocketCreator(host, port) as sock:
      sock.sendall(data)
  """

  def __init__(self, host, port, timeout=1.0):
    self.timeout = timeout
    self.address = (host, port)

  def __enter__(self):
    connection = socket.create_connection(self.address, timeout=self.timeout)
    self.socket = connection
    return connection

  def __exit__(self, *args):
    self.socket.close()


class TimedTestCase(unittest.TestCase):
  """TestCase base adding a relative-tolerance comparison for timings."""

  def assertValuesAlmostEqual(self, expected, actual, tolerance=0.05):
    """Fail unless |actual| is within |tolerance| * |expected| of |expected|.

    Equivalent to
      assertTrue(abs(actual - expected) <= abs(tolerance * expected))
    but with a more helpful failure message.

    Args:
      expected: the expected value (e.g. milliseconds).
      actual: the measured value.
      tolerance: allowed deviation as a fraction of |expected|
          (0.05 means +/- 5%).
    """
    # abs() keeps the window non-empty when |expected| is negative.
    delta = abs(tolerance * expected)
    if actual > expected + delta or actual < expected - delta:
      self.fail('%s is not equal to expected %s +/- %s%%' % (
          actual, expected, 100 * tolerance))


class TcpTrafficShaperTest(TimedTestCase):
  """Checks TrafficShaper delay and bandwidth limits on TCP traffic."""

  def setUp(self):
    self.host = platformsettings.get_server_ip_address()
    self.port = TEST_HTTP_PORT
    self.tcp_socket_creator = TcpTestSocketCreator(self.host, self.port)
    self.timer = TIMER

  def TrafficShaper(self, **kwargs):
    """Return a TrafficShaper restricted to this test's host and port."""
    return trafficshaper.TrafficShaper(
        host=self.host, ports=(self.port,), **kwargs)

  def GetTcpSendTimeMs(self, num_bytes):
    """Return time in milliseconds to send |num_bytes|.

    Measured from just before the first byte is sent until the server
    finished reading the request. The server replies with its own timer
    value, so client and server must share a clock.
    """
    # Build the payload before starting the clock so allocation of a large
    # string is not counted as send time.
    request_data = '\x00' * num_bytes
    with self.tcp_socket_creator as s:
      start_time = self.timer()
      s.sendall(request_data)
      # The server reads until EOF, so half-close the write side to mark
      # the request complete; the reply still arrives on the read side.
      s.shutdown(socket.SHUT_WR)
      read_time = s.recv(1024)
    return GetElapsedMs(start_time, read_time)

  def GetTcpReceiveTimeMs(self, num_bytes):
    """Return time in milliseconds to receive |num_bytes|.

    The server pads its reply to |num_bytes|. The reply begins with the
    server's timer value at the time it read the request, followed by a
    newline. The result spans from that moment until the last reply byte
    arrives.

    Raises:
      AssertionError (via self.fail): if the connection closes before
          |num_bytes| arrive or no timestamp line was received.
    """
    with self.tcp_socket_creator as s:
      s.sendall('%s%s\n' % (RESPONSE_SIZE_KEY, num_bytes))
      # The server reads until EOF, so half-close the write side to mark
      # the request complete.
      s.shutdown(socket.SHUT_WR)
      num_remaining_bytes = num_bytes
      header = ''
      read_time = None
      while num_remaining_bytes > 0:
        response_data = s.recv(4096)
        if not response_data:
          # recv() keeps returning '' once the peer has closed; without
          # this check the loop would never terminate.
          self.fail('Connection closed with %d of %d bytes not received' % (
              num_remaining_bytes, num_bytes))
        num_remaining_bytes -= len(response_data)
        if read_time is None:
          # The timestamp line may be split across several recv() calls.
          header += response_data
          if '\n' in header:
            read_time = header.split('\n', 1)[0]
      # Stop the clock on the last byte, before the socket is closed.
      end_time = self.timer()
    if read_time is None:
      self.fail('No newline-terminated timestamp in the response')
    return GetElapsedMs(read_time, end_time)

  def testTcpConnectToIp(self):
    """Verify that it takes |delay_ms| to establish a TCP connection."""
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    with TimedTcpServer(self.host, self.port):
      for delay_ms in (100, 175):
        with self.TrafficShaper(delay_ms=delay_ms):
          start_time = self.timer()
          with self.tcp_socket_creator:
            connect_time = GetElapsedMs(start_time, self.timer())
        self.assertValuesAlmostEqual(delay_ms, connect_time, tolerance=0.12)

  def testTcpUploadShaping(self):
    """Verify that 'up' bandwidth is shaped on TCP connections."""
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    num_bytes = 1024 * 100
    bandwidth_kbits = 2000
    expected_ms = 8.0 * num_bytes / bandwidth_kbits
    with TimedTcpServer(self.host, self.port):
      with self.TrafficShaper(up_bandwidth='%sKbit/s' % bandwidth_kbits):
        send_ms = self.GetTcpSendTimeMs(num_bytes)
        self.assertValuesAlmostEqual(expected_ms, send_ms)

  def testTcpDownloadShaping(self):
    """Verify that 'down' bandwidth is shaped on TCP connections."""
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    num_bytes = 1024 * 100
    bandwidth_kbits = 2000
    expected_ms = 8.0 * num_bytes / bandwidth_kbits
    with TimedTcpServer(self.host, self.port):
      with self.TrafficShaper(down_bandwidth='%sKbit/s' % bandwidth_kbits):
        receive_ms = self.GetTcpReceiveTimeMs(num_bytes)
        self.assertValuesAlmostEqual(expected_ms, receive_ms)

  def testTcpInterleavedDownloads(self):
    # TODO(slamm): write tcp interleaved downloads test
    pass


class UdpTrafficShaperTest(TimedTestCase):
  """Checks TrafficShaper delay on UDP traffic."""

  def setUp(self):
    self.host = platformsettings.get_server_ip_address()
    self.dns_port = TEST_DNS_PORT
    self.timer = TIMER

  def TrafficShaper(self, **kwargs):
    """Return a TrafficShaper restricted to this test's host and DNS port."""
    return trafficshaper.TrafficShaper(
        host=self.host, ports=(self.dns_port,), **kwargs)

  def GetUdpSendReceiveTimesMs(self):
    """Return (send_ms, receive_ms) for one UDP request/response.

    send_ms spans from sending a datagram until the server handled it.
    receive_ms spans from that moment until the reply arrived here. Both
    are computed against the server's reported timer value, so client and
    server must share a clock.
    """
    udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
      # Bound the wait so a lost datagram fails the test instead of hanging
      # forever (same default as TcpTestSocketCreator).
      udp_socket.settimeout(1.0)
      start_time = self.timer()
      udp_socket.sendto('test data\n', (self.host, self.dns_port))
      read_time = udp_socket.recv(1024)
      end_time = self.timer()
    finally:
      udp_socket.close()
    return (GetElapsedMs(start_time, read_time),
            GetElapsedMs(read_time, end_time))

  def testUdpDelay(self):
    if not platformsettings.has_ipfw():
      logging.warning('ipfw is not available in path. Skip the test')
      return
    for delay_ms in (100, 170):
      expected_ms = delay_ms / 2
      with TimedUdpServer(self.host, self.dns_port):
        with self.TrafficShaper(delay_ms=delay_ms):
          send_ms, receive_ms = self.GetUdpSendReceiveTimesMs()
          self.assertValuesAlmostEqual(expected_ms, send_ms, tolerance=0.10)
          self.assertValuesAlmostEqual(expected_ms, receive_ms, tolerance=0.10)

  def testUdpInterleavedDelay(self):
    # TODO(slamm): write udp interleaved udp delay test
    pass


class TcpAndUdpTrafficShaperTest(TimedTestCase):
  """Placeholder for tests that mix TCP and UDP traffic (none yet)."""
  # TODO(slamm): Test concurrent TCP and UDP traffic


# TODO(slamm): Test the packet loss rate (try different ports).


if __name__ == '__main__':
  # To see debug-level log messages, uncomment:
  #   logging.getLogger().setLevel(logging.DEBUG)
  unittest.main()