# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import time
from multiprocessing import Queue, queues


class QueueBarrierTimeout(Exception):
    """Raised when a QueueBarrier wait does not complete within its timeout."""


class QueueBarrier(object):
    """This class implements a simple barrier to synchronize processes. The
    barrier relies on the fact that there is a single "master" process and |n|
    different "slave" processes, which keeps the implementation simple. Given
    this hierarchy, the slaves and the master can also exchange a token while
    passing through the barrier: each slave hands one token to the master, and
    the master hands one token back to every slave.

    The so-called "master" shall call master_barrier() while each "slave"
    shall call slave_barrier().

    If the same group of |n| slaves and the same master are participating in
    the barrier, it is totally safe to reuse the barrier several times with
    the same group of processes. After a QueueBarrierTimeout, however, the
    underlying queues may be left inconsistent (tokens already consumed by the
    master, or slave tokens left behind), so the barrier should not be reused.
    """


    def __init__(self, n):
        """Initializes the barrier with |n| slave processes and a master.

        @param n: The number of slave processes.
        """
        self.n_ = n
        # Slaves -> master: each arriving slave puts its token here.
        self.queue_master_ = Queue()
        # Master -> slaves: the master puts one release token per slave here.
        self.queue_slave_ = Queue()


    def master_barrier(self, token=None, timeout=None):
        """Makes the master wait until all the "n" slaves have reached this
        point, then releases every slave by sending it |token|.

        @param token: A value passed to every slave.
        @param timeout: The maximum total time, in seconds, to wait for all
                the slaves to arrive. A None value will block forever.

        @return: The list of tokens received from the slaves, in the order
                they were received.

        @raises QueueBarrierTimeout: if not every slave arrived before
                |timeout| expired. No slave is released in that case.
        """
        # time.monotonic is immune to wall-clock jumps; it does not exist on
        # Python 2, where time.time is the closest substitute.
        now = getattr(time, 'monotonic', time.time)
        # A single deadline bounds the *total* wait. Passing |timeout| to each
        # get() would allow up to n * timeout seconds in the worst case.
        deadline = None if timeout is None else now() + timeout
        result = []
        try:
            for _ in range(self.n_):
                if deadline is None:
                    remaining = None
                else:
                    # Clamp so a late iteration still does a non-blocking poll
                    # rather than passing a negative timeout.
                    remaining = max(0, deadline - now())
                result.append(self.queue_master_.get(timeout=remaining))
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()
        # Release all the blocked slaves.
        for _ in range(self.n_):
            self.queue_slave_.put(token)
        return result


    def slave_barrier(self, token=None, timeout=None):
        """Makes a slave wait until all the "n" slaves and the master have
        reached this point.

        The slave first hands |token| to the master, then blocks until the
        master releases the barrier.

        @param token: A value passed to the master.
        @param timeout: The timeout, in seconds, to wait for the master to
                release the barrier. A None value will block forever.

        @return: The token the master passed to master_barrier().

        @raises QueueBarrierTimeout: if the master did not release this slave
                before |timeout| expired. |token| has already been handed to
                the master at that point and is not withdrawn.
        """
        self.queue_master_.put(token)
        try:
            return self.queue_slave_.get(timeout=timeout)
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()