summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Grönholm <alex.gronholm@nextday.fi>2021-09-12 16:11:13 +0300
committerAlex Grönholm <alex.gronholm@nextday.fi>2021-09-13 01:07:44 +0300
commit9568bf2f1297c87ec1b93306b79de925fb2da08e (patch)
tree2969e7951f8883c41f1359ebbc98bcf85a3fdad6
parent0a6b0f683edee8bf22d85dc655ad61a8285fd312 (diff)
downloadapscheduler-9568bf2f1297c87ec1b93306b79de925fb2da08e.tar.gz
Implemented one-shot event subscriptions
Such subscriptions are delivered the first matching event and then unsubscribed automatically.
-rw-r--r--src/apscheduler/abc.py5
-rw-r--r--src/apscheduler/eventbrokers/async_local.py26
-rw-r--r--src/apscheduler/eventbrokers/base.py14
-rw-r--r--src/apscheduler/eventbrokers/local.py13
-rw-r--r--src/apscheduler/schedulers/sync.py2
-rw-r--r--src/apscheduler/workers/sync.py2
-rw-r--r--tests/test_eventbrokers.py157
7 files changed, 93 insertions, 126 deletions
diff --git a/src/apscheduler/abc.py b/src/apscheduler/abc.py
index 3d85eed..58d74cd 100644
--- a/src/apscheduler/abc.py
+++ b/src/apscheduler/abc.py
@@ -92,13 +92,16 @@ class EventSource(metaclass=ABCMeta):
@abstractmethod
def subscribe(
self, callback: Callable[[Event], Any],
- event_types: Optional[Iterable[type[Event]]] = None
+ event_types: Optional[Iterable[type[Event]]] = None,
+ *,
+ one_shot: bool = False
) -> Subscription:
"""
Subscribe to events from this event source.
:param callback: callable to be called with the event object when an event is published
:param event_types: an iterable of concrete Event classes to subscribe to
+ :param one_shot: if ``True``, automatically unsubscribe after the first matching event
"""
diff --git a/src/apscheduler/eventbrokers/async_local.py b/src/apscheduler/eventbrokers/async_local.py
index 590f0cb..79030f3 100644
--- a/src/apscheduler/eventbrokers/async_local.py
+++ b/src/apscheduler/eventbrokers/async_local.py
@@ -36,15 +36,21 @@ class LocalAsyncEventBroker(AsyncEventBroker, BaseEventBroker):
await self.publish_local(event)
async def publish_local(self, event: Event) -> None:
- async def deliver_event(func: Callable[[Event], Any]) -> None:
- try:
- retval = func(event)
- if iscoroutine(retval):
- await retval
- except BaseException:
- self._logger.exception('Error delivering %s event', event.__class__.__name__)
-
event_type = type(event)
- for subscription in self._subscriptions.values():
+ one_shot_tokens: list[object] = []
+ for token, subscription in self._subscriptions.items():
if subscription.event_types is None or event_type in subscription.event_types:
- self._task_group.start_soon(deliver_event, subscription.callback)
+ self._task_group.start_soon(self._deliver_event, subscription.callback, event)
+ if subscription.one_shot:
+ one_shot_tokens.append(subscription.token)
+
+ for token in one_shot_tokens:
+ super().unsubscribe(token)
+
+ async def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None:
+ try:
+ retval = func(event)
+ if iscoroutine(retval):
+ await retval
+ except BaseException:
+ self._logger.exception('Error delivering %s event', event.__class__.__name__)
diff --git a/src/apscheduler/eventbrokers/base.py b/src/apscheduler/eventbrokers/base.py
index 23bae8a..9947f68 100644
--- a/src/apscheduler/eventbrokers/base.py
+++ b/src/apscheduler/eventbrokers/base.py
@@ -16,31 +16,33 @@ from ..exceptions import DeserializationError
class LocalSubscription(Subscription):
callback: Callable[[Event], Any]
event_types: Optional[set[type[Event]]]
+ one_shot: bool
+ token: object
_source: BaseEventBroker
- _token: object
def unsubscribe(self) -> None:
- self._source.unsubscribe(self._token)
+ self._source.unsubscribe(self.token)
@attr.define(eq=False)
class BaseEventBroker(EventBroker):
_logger: Logger = attr.field(init=False)
- _subscriptions: dict[object, Subscription] = attr.field(init=False, factory=dict)
+ _subscriptions: dict[object, LocalSubscription] = attr.field(init=False, factory=dict)
def __attrs_post_init__(self) -> None:
self._logger = getLogger(self.__class__.__module__)
def subscribe(self, callback: Callable[[Event], Any],
- event_types: Optional[Iterable[type[Event]]] = None) -> Subscription:
+ event_types: Optional[Iterable[type[Event]]] = None, *,
+ one_shot: bool = False) -> Subscription:
types = set(event_types) if event_types else None
token = object()
- subscription = LocalSubscription(callback, types, self, token)
+ subscription = LocalSubscription(callback, types, one_shot, token, self)
self._subscriptions[token] = subscription
return subscription
def unsubscribe(self, token: object) -> None:
- self._subscriptions.pop(token)
+ self._subscriptions.pop(token, None)
class DistributedEventBrokerMixin:
diff --git a/src/apscheduler/eventbrokers/local.py b/src/apscheduler/eventbrokers/local.py
index 24de3eb..acf0c9a 100644
--- a/src/apscheduler/eventbrokers/local.py
+++ b/src/apscheduler/eventbrokers/local.py
@@ -31,13 +31,14 @@ class LocalEventBroker(BaseEventBroker):
del self._executor
def subscribe(self, callback: Callable[[Event], Any],
- event_types: Optional[Iterable[type[Event]]] = None) -> Subscription:
+ event_types: Optional[Iterable[type[Event]]] = None, *,
+ one_shot: bool = False) -> Subscription:
if iscoroutinefunction(callback):
raise ValueError('Coroutine functions are not supported as callbacks on a synchronous '
'event source')
with self._subscriptions_lock:
- return super().subscribe(callback, event_types)
+ return super().subscribe(callback, event_types, one_shot=one_shot)
def unsubscribe(self, token: object) -> None:
with self._subscriptions_lock:
@@ -49,9 +50,15 @@ class LocalEventBroker(BaseEventBroker):
def publish_local(self, event: Event) -> None:
event_type = type(event)
with self._subscriptions_lock:
- for subscription in self._subscriptions.values():
+ one_shot_tokens: list[object] = []
+ for token, subscription in self._subscriptions.items():
if subscription.event_types is None or event_type in subscription.event_types:
self._executor.submit(self._deliver_event, subscription.callback, event)
+ if subscription.one_shot:
+ one_shot_tokens.append(subscription.token)
+
+ for token in one_shot_tokens:
+ super().unsubscribe(token)
def _deliver_event(self, func: Callable[[Event], Any], event: Event) -> None:
try:
diff --git a/src/apscheduler/schedulers/sync.py b/src/apscheduler/schedulers/sync.py
index dd3f37e..221b284 100644
--- a/src/apscheduler/schedulers/sync.py
+++ b/src/apscheduler/schedulers/sync.py
@@ -73,7 +73,7 @@ class Scheduler:
# Start the scheduler and return when it has signalled readiness or raised an exception
start_future: Future[Event] = Future()
- with self._events.subscribe(start_future.set_result):
+ with self._events.subscribe(start_future.set_result, one_shot=True):
run_future = self._executor.submit(self.run)
wait([start_future, run_future], return_when=FIRST_COMPLETED)
diff --git a/src/apscheduler/workers/sync.py b/src/apscheduler/workers/sync.py
index 824cce8..be805ee 100644
--- a/src/apscheduler/workers/sync.py
+++ b/src/apscheduler/workers/sync.py
@@ -64,7 +64,7 @@ class Worker:
# Start the worker and return when it has signalled readiness or raised an exception
start_future: Future[None] = Future()
- with self._events.subscribe(start_future.set_result):
+ with self._events.subscribe(start_future.set_result, one_shot=True):
self._executor = ThreadPoolExecutor(1)
run_future = self._executor.submit(self.run)
wait([start_future, run_future], return_when=FIRST_COMPLETED)
diff --git a/tests/test_eventbrokers.py b/tests/test_eventbrokers.py
index 7097001..024b63d 100644
--- a/tests/test_eventbrokers.py
+++ b/tests/test_eventbrokers.py
@@ -79,16 +79,10 @@ def async_broker(request: FixtureRequest) -> Callable[[], AsyncEventBroker]:
class TestEventBroker:
def test_publish_subscribe(self, broker: EventBroker) -> None:
- def subscriber1(event) -> None:
- queue.put_nowait(event)
-
- def subscriber2(event) -> None:
- queue.put_nowait(event)
-
queue = Queue()
with broker:
- broker.subscribe(subscriber1)
- broker.subscribe(subscriber2)
+ broker.subscribe(queue.put_nowait)
+ broker.subscribe(queue.put_nowait)
event = ScheduleAdded(
schedule_id='schedule1',
next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
@@ -102,6 +96,25 @@ class TestEventBroker:
assert event1.schedule_id == 'schedule1'
assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)
+ def test_subscribe_one_shot(self, broker: EventBroker) -> None:
+ queue = Queue()
+ with broker:
+ broker.subscribe(queue.put_nowait, one_shot=True)
+ event = ScheduleAdded(
+ schedule_id='schedule1',
+ next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
+ broker.publish(event)
+ event = ScheduleAdded(
+ schedule_id='schedule2',
+ next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc))
+ broker.publish(event)
+ received_event = queue.get(timeout=3)
+ with pytest.raises(Empty):
+ queue.get(timeout=0.1)
+
+ assert isinstance(received_event, ScheduleAdded)
+ assert received_event.schedule_id == 'schedule1'
+
def test_unsubscribe(self, broker: EventBroker, caplog) -> None:
queue = Queue()
with broker:
@@ -140,24 +153,18 @@ class TestEventBroker:
@pytest.mark.anyio
class TestAsyncEventBroker:
async def test_publish_subscribe(self, async_broker: AsyncEventBroker) -> None:
- def subscriber1(event) -> None:
- send.send_nowait(event)
-
- async def subscriber2(event) -> None:
- await send.send(event)
-
send, receive = create_memory_object_stream(2)
async with async_broker:
- async_broker.subscribe(subscriber1)
- async_broker.subscribe(subscriber2)
+ async_broker.subscribe(send.send)
+ async_broker.subscribe(send.send_nowait)
event = ScheduleAdded(
schedule_id='schedule1',
next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
await async_broker.publish(event)
- with fail_after(3):
- event1 = await receive.receive()
- event2 = await receive.receive()
+ with fail_after(3):
+ event1 = await receive.receive()
+ event2 = await receive.receive()
assert event1 == event2
assert isinstance(event1, ScheduleAdded)
@@ -165,6 +172,28 @@ class TestAsyncEventBroker:
assert event1.schedule_id == 'schedule1'
assert event1.next_fire_time == datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc)
+ async def test_subscribe_one_shot(self, async_broker: AsyncEventBroker) -> None:
+ send, receive = create_memory_object_stream(2)
+ async with async_broker:
+ async_broker.subscribe(send.send, one_shot=True)
+ event = ScheduleAdded(
+ schedule_id='schedule1',
+ next_fire_time=datetime(2021, 9, 11, 12, 31, 56, 254867, timezone.utc))
+ await async_broker.publish(event)
+ event = ScheduleAdded(
+ schedule_id='schedule2',
+ next_fire_time=datetime(2021, 9, 12, 8, 42, 11, 968481, timezone.utc))
+ await async_broker.publish(event)
+
+ with fail_after(3):
+ received_event = await receive.receive()
+
+ with pytest.raises(TimeoutError), fail_after(0.1):
+ await receive.receive()
+
+ assert isinstance(received_event, ScheduleAdded)
+ assert received_event.schedule_id == 'schedule1'
+
async def test_unsubscribe(self, async_broker: AsyncEventBroker) -> None:
send, receive = create_memory_object_stream()
async with async_broker:
@@ -191,92 +220,12 @@ class TestAsyncEventBroker:
raise Exception('foo')
timestamp = datetime.now(timezone.utc)
- events = []
+ send, receive = create_memory_object_stream()
async with async_broker:
async_broker.subscribe(bad_subscriber)
- async_broker.subscribe(events.append)
+ async_broker.subscribe(send.send)
await async_broker.publish(Event(timestamp=timestamp))
- assert isinstance(events[0], Event)
- assert events[0].timestamp == timestamp
- assert 'Error delivering Event' in caplog.text
-#
-# def test_subscribe_coroutine_callback(self) -> None:
-# async def callback(event: Event) -> None:
-# pass
-#
-# with EventBroker() as eventhub:
-# with pytest.raises(ValueError, match='Coroutine functions are not supported'):
-# eventhub.subscribe(callback)
-#
-# def test_relay_events(self) -> None:
-# timestamp = datetime.now(timezone.utc)
-# events = []
-# with EventBroker() as eventhub1, EventBroker() as eventhub2:
-# eventhub2.relay_events_from(eventhub1)
-# eventhub2.subscribe(events.append)
-# eventhub1.publish(Event(timestamp=timestamp))
-#
-# assert isinstance(events[0], Event)
-# assert events[0].timestamp == timestamp
-#
-#
-# @pytest.mark.anyio
-# class TestAsyncEventHub:
-# async def test_publish(self) -> None:
-# async def async_setitem(event: Event) -> None:
-# events[1] = event
-#
-# timestamp = datetime.now(timezone.utc)
-# events: List[Optional[Event]] = [None, None]
-# async with AsyncEventBroker() as eventhub:
-# eventhub.subscribe(partial(setitem, events, 0))
-# eventhub.subscribe(async_setitem)
-# eventhub.publish(Event(timestamp=timestamp))
-#
-# assert events[0] is events[1]
-# assert isinstance(events[0], Event)
-# assert events[0].timestamp == timestamp
-#
-# async def test_unsubscribe(self) -> None:
-# timestamp = datetime.now(timezone.utc)
-# events = []
-# async with AsyncEventBroker() as eventhub:
-# token = eventhub.subscribe(events.append)
-# eventhub.publish(Event(timestamp=timestamp))
-# eventhub.unsubscribe(token)
-# eventhub.publish(Event(timestamp=timestamp))
-#
-# assert len(events) == 1
-#
-# async def test_publish_no_subscribers(self, caplog: LogCaptureFixture) -> None:
-# async with AsyncEventBroker() as eventhub:
-# eventhub.publish(Event(timestamp=datetime.now(timezone.utc)))
-#
-# assert not caplog.text
-#
-# async def test_publish_exception(self, caplog: LogCaptureFixture) -> None:
-# def bad_subscriber(event: Event) -> None:
-# raise Exception('foo')
-#
-# timestamp = datetime.now(timezone.utc)
-# events = []
-# async with AsyncEventBroker() as eventhub:
-# eventhub.subscribe(bad_subscriber)
-# eventhub.subscribe(events.append)
-# eventhub.publish(Event(timestamp=timestamp))
-#
-# assert isinstance(events[0], Event)
-# assert events[0].timestamp == timestamp
-# assert 'Error delivering Event' in caplog.text
-#
-# async def test_relay_events(self) -> None:
-# timestamp = datetime.now(timezone.utc)
-# events = []
-# async with AsyncEventBroker() as eventhub1, AsyncEventBroker() as eventhub2:
-# eventhub1.relay_events_from(eventhub2)
-# eventhub1.subscribe(events.append)
-# eventhub2.publish(Event(timestamp=timestamp))
-#
-# assert isinstance(events[0], Event)
-# assert events[0].timestamp == timestamp
+ received_event = await receive.receive()
+ assert received_event.timestamp == timestamp
+ assert 'Error delivering Event' in caplog.text