| t | import asyncio | t | import asyncio |
| | | |
class Portal:
    """Asyncio rendezvous point for a fixed number of parties.

    Each caller of :meth:`wait` is parked until ``parties`` callers have
    arrived; every caller then receives its own arrival index.  Once the
    portal has tripped, further :meth:`wait` calls raise ``RuntimeError``
    until :meth:`reset` is awaited.

    Callers may optionally tag their arrival with a *topic*.  The first
    non-``None`` topic is remembered, and any later caller passing a
    different non-``None`` topic is rejected with ``ValueError``.
    """

    def __init__(self, parties):
        """Create a portal that trips after *parties* arrivals."""
        self._parties = parties
        # Number of arrivals still required before the portal trips.
        self._count = parties
        # First non-None topic supplied by a caller of wait(), if any.
        self._topic = None
        # Serialises every mutation of the bookkeeping fields below.
        self._lock = asyncio.Lock()
        # Set when the portal trips, cleared again by reset().
        self._event = asyncio.Event()
        # One future per arrival slot; slot i belongs to the i-th arrival.
        self._results = [None] * parties
        # Index that will be handed to the next arrival.
        self._current_index = 0

    async def wait(self, topic=None):
        """Arrive at the portal and block until all parties have arrived.

        Args:
            topic: Optional tag for this rendezvous.  ``None`` means "no
                opinion" and is always accepted.

        Returns:
            The zero-based arrival index of this caller.

        Raises:
            RuntimeError: If the portal has already tripped and has not
                been reset, or if :meth:`reset` is called while this
                caller is still parked.
            ValueError: If *topic* conflicts with the topic recorded by
                an earlier caller.  A rejected caller does not consume an
                arrival slot.
        """
        async with self._lock:
            if self._count <= 0:
                raise RuntimeError('Barrier broken')

            # Validate the topic *before* touching any counters so that a
            # rejected caller leaves the portal state untouched.
            if topic is not None:
                if self._topic is None:
                    self._topic = topic
                elif self._topic != topic:
                    raise ValueError('Topic mismatch')

            index = self._current_index
            self._current_index += 1

            fut = asyncio.get_running_loop().create_future()
            self._results[index] = fut
            self._count -= 1

            if self._count == 0:
                self._event.set()
                for position, pending in enumerate(self._results):
                    # A waiter that was cancelled leaves a finished future
                    # behind; resolving it again would raise
                    # InvalidStateError.
                    if not pending.done():
                        pending.set_result(position)

        # The future is resolved when the portal trips and failed when the
        # portal is reset, so awaiting it alone is sufficient.
        return await fut

    async def reset(self):
        """Return the portal to its initial, empty state.

        Callers still parked in :meth:`wait` are released with
        ``RuntimeError('Barrier broken')`` rather than being left waiting
        on futures that nothing would ever resolve.
        """
        async with self._lock:
            for pending in self._results:
                if pending is not None and not pending.done():
                    pending.set_exception(RuntimeError('Barrier broken'))
            self._count = self._parties
            self._topic = None
            self._event.clear()
            self._results = [None] * self._parties
            self._current_index = 0

    @property
    def parties(self):
        """Number of arrivals required to trip the portal."""
        return self._parties

    @property
    def n_waiting(self):
        """Number of arrival slots handed out since the last reset."""
        return self._current_index

    @property
    def topic(self):
        """Topic recorded for the current rendezvous, or ``None``."""
        return self._topic