diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/storage/events.py | 155 |
1 files changed, 144 insertions, 11 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 0307b2af3c..07e873fce4 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -19,12 +19,14 @@ from twisted.internet import defer, reactor from synapse.events import FrozenEvent, USE_FROZEN_DICTS from synapse.events.utils import prune_event +from synapse.util.async import ObservableDeferred from synapse.util.logcontext import preserve_fn, PreserveLoggingContext from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from canonicaljson import encode_canonical_json -from collections import namedtuple +from collections import deque, namedtuple + import logging import math @@ -50,6 +52,80 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events +class _EventPeristenceQueue(object): + """Queues up events so that they can be persisted in bulk with only one + concurrent transaction per room. + """ + + _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( + "events_and_contexts", "current_state", "backfilled", "deferred", + )) + + def __init__(self): + self._event_persist_queues = {} + self._currently_persisting_rooms = set() + + def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): + """Add events to the queue, with the given persist_event options. + """ + queue = self._event_persist_queues.setdefault(room_id, deque()) + if queue: + end_item = queue[-1] + if end_item.current_state or current_state: + # We perist events with current_state set to True one at a time + pass + if end_item.backfilled == backfilled: + end_item.events_and_contexts.extend(events_and_contexts) + return end_item.deferred.observe() + + deferred = ObservableDeferred(defer.Deferred()) + + queue.append(self._EventPersistQueueItem( + events_and_contexts=events_and_contexts, + backfilled=backfilled, + current_state=current_state, + deferred=deferred, + )) + + return deferred.observe() + + def handle_queue(self, room_id, callback): + """Attempts to handle the queue for a room if not already being handled. + + The given callback will be invoked with a 'queue' arg, which is a + generator over _EventPersistQueueItem's. The queue will finish if there + are no longer any items in the room queue. + + This function should therefore be called whenever anything is added + to the queue. + + If another callback is currently handling the queue then it will not be + invoked. + """ + + if room_id in self._currently_persisting_rooms: + return + + self._currently_persisting_rooms.add(room_id) + + try: + callback(self._get_drainining_queue(room_id)) + finally: + self._currently_persisting_rooms.discard(room_id) + + def _get_drainining_queue(self, room_id): + queue = self._event_persist_queues.pop(room_id, None) + if not queue: + return + + try: + while True: + yield queue.popleft() + except IndexError: + # Queue has been drained. + pass + + class EventsStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" @@ -59,19 +135,80 @@ class EventsStore(SQLBaseStore): self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) - @defer.inlineCallbacks + self._event_persist_queue = _EventPeristenceQueue() + def persist_events(self, events_and_contexts, backfilled=False): """ Write events to the database Args: events_and_contexts: list of tuples of (event, context) backfilled: ? + """ + partitioned = {} + for event, ctx in events_and_contexts: + partitioned.setdefault(event.room_id, []).append((event, ctx)) + + deferreds = [] + for room_id, evs_ctxs in partitioned.items(): + d = self._event_persist_queue.add_to_queue( + room_id, evs_ctxs, + backfilled=backfilled, + current_state=None, + ) + deferreds.append(d) - Returns: Tuple of stream_orderings where the first is the minimum and - last is the maximum stream ordering assigned to the events when - persisting. + for room_id in partitioned.keys(): + self._maybe_start_persisting(room_id) - """ + return defer.gatherResults(deferreds, consumeErrors=True) + + @defer.inlineCallbacks + @log_function + def persist_event(self, event, context, current_state=None, backfilled=False): + deferred = self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], + backfilled=backfilled, + current_state=current_state, + ) + + self._maybe_start_persisting(event.room_id) + + yield deferred + + max_persisted_id = yield self._stream_id_gen.get_current_token() + defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) + + def _maybe_start_persisting(self, room_id): + @defer.inlineCallbacks + def persisting_queue(queue): + for item in queue: + try: + ret = None + if item.current_state: + for event, context in item.events_and_contexts: + # There should only ever be one item in + # events_and_contexts when current_state is + # not None + yield self._persist_event( + event, context, + current_state=item.current_state, + backfilled=item.backfilled, + ) + else: + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) + logger.info("Resolving with ret: %r", ret) + item.deferred.callback(ret) + except Exception as e: + logger.exception("Failed to persist events") + item.deferred.errback(e) + + self._event_persist_queue.handle_queue(room_id, persisting_queue) + + @defer.inlineCallbacks + def _persist_events(self, events_and_contexts, backfilled=False): if not events_and_contexts: return @@ -118,8 +255,7 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function - def persist_event(self, event, context, current_state=None, backfilled=False): - + def _persist_event(self, event, context, current_state=None, backfilled=False): try: with self._stream_id_gen.get_next() as stream_ordering: with self._state_groups_id_gen.get_next() as state_group_id: @@ -136,9 +272,6 @@ class EventsStore(SQLBaseStore): except _RollbackButIsFineException: pass - max_persisted_id = yield self._stream_id_gen.get_current_token() - defer.returnValue((stream_ordering, max_persisted_id)) - @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, |