summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/storage/events.py155
1 files changed, 144 insertions, 11 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 0307b2af3c..07e873fce4 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,12 +19,14 @@ from twisted.internet import defer, reactor
 from synapse.events import FrozenEvent, USE_FROZEN_DICTS
 from synapse.events.utils import prune_event
 
+from synapse.util.async import ObservableDeferred
 from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 
 from canonicaljson import encode_canonical_json
-from collections import namedtuple
+from collections import deque, namedtuple
+
 
 import logging
 import math
@@ -50,6 +52,80 @@ EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 
 
+class _EventPeristenceQueue(object):
+    """Queues up events so that they can be persisted in bulk with only one
+    concurrent transaction per room.
+    """
+
+    _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
+        "events_and_contexts", "current_state", "backfilled", "deferred",
+    ))
+
+    def __init__(self):
+        self._event_persist_queues = {}
+        self._currently_persisting_rooms = set()
+
+    def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
+        """Add events to the queue, with the given persist_event options.
+        """
+        queue = self._event_persist_queues.setdefault(room_id, deque())
+        if queue:
+            end_item = queue[-1]
+            if end_item.current_state or current_state:
+                # We perist events with current_state set to True one at a time
+                pass
+            if end_item.backfilled == backfilled:
+                end_item.events_and_contexts.extend(events_and_contexts)
+                return end_item.deferred.observe()
+
+        deferred = ObservableDeferred(defer.Deferred())
+
+        queue.append(self._EventPersistQueueItem(
+            events_and_contexts=events_and_contexts,
+            backfilled=backfilled,
+            current_state=current_state,
+            deferred=deferred,
+        ))
+
+        return deferred.observe()
+
+    def handle_queue(self, room_id, callback):
+        """Attempts to handle the queue for a room if not already being handled.
+
+        The given callback will be invoked with a 'queue' arg, which is a
+        generator over _EventPersistQueueItem's. The queue will finish if there
+        are no longer any items in the room queue.
+
+        This function should therefore be called whenever anything is added
+        to the queue.
+
+        If another callback is currently handling the queue then it will not be
+        invoked.
+        """
+
+        if room_id in self._currently_persisting_rooms:
+            return
+
+        self._currently_persisting_rooms.add(room_id)
+
+        try:
+            callback(self._get_drainining_queue(room_id))
+        finally:
+            self._currently_persisting_rooms.discard(room_id)
+
+    def _get_drainining_queue(self, room_id):
+        queue = self._event_persist_queues.pop(room_id, None)
+        if not queue:
+            return
+
+        try:
+            while True:
+                yield queue.popleft()
+        except IndexError:
+            # Queue has been drained.
+            pass
+
+
 class EventsStore(SQLBaseStore):
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
 
@@ -59,19 +135,80 @@ class EventsStore(SQLBaseStore):
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
 
-    @defer.inlineCallbacks
+        self._event_persist_queue = _EventPeristenceQueue()
+
     def persist_events(self, events_and_contexts, backfilled=False):
         """
         Write events to the database
         Args:
             events_and_contexts: list of tuples of (event, context)
             backfilled: ?
+        """
+        partitioned = {}
+        for event, ctx in events_and_contexts:
+            partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+        deferreds = []
+        for room_id, evs_ctxs in partitioned.items():
+            d = self._event_persist_queue.add_to_queue(
+                room_id, evs_ctxs,
+                backfilled=backfilled,
+                current_state=None,
+            )
+            deferreds.append(d)
 
-        Returns: Tuple of stream_orderings where the first is the minimum and
-            last is the maximum stream ordering assigned to the events when
-            persisting.
+        for room_id in partitioned.keys():
+            self._maybe_start_persisting(room_id)
 
-        """
+        return defer.gatherResults(deferreds, consumeErrors=True)
+
+    @defer.inlineCallbacks
+    @log_function
+    def persist_event(self, event, context, current_state=None, backfilled=False):
+        deferred = self._event_persist_queue.add_to_queue(
+            event.room_id, [(event, context)],
+            backfilled=backfilled,
+            current_state=current_state,
+        )
+
+        self._maybe_start_persisting(event.room_id)
+
+        yield deferred
+
+        max_persisted_id = yield self._stream_id_gen.get_current_token()
+        defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
+
+    def _maybe_start_persisting(self, room_id):
+        @defer.inlineCallbacks
+        def persisting_queue(queue):
+            for item in queue:
+                try:
+                    ret = None
+                    if item.current_state:
+                        for event, context in item.events_and_contexts:
+                            # There should only ever be one item in
+                            # events_and_contexts when current_state is
+                            # not None
+                            yield self._persist_event(
+                                event, context,
+                                current_state=item.current_state,
+                                backfilled=item.backfilled,
+                            )
+                    else:
+                        yield self._persist_events(
+                            item.events_and_contexts,
+                            backfilled=item.backfilled,
+                        )
+                    logger.info("Resolving with ret: %r", ret)
+                    item.deferred.callback(ret)
+                except Exception as e:
+                    logger.exception("Failed to persist events")
+                    item.deferred.errback(e)
+
+        self._event_persist_queue.handle_queue(room_id, persisting_queue)
+
+    @defer.inlineCallbacks
+    def _persist_events(self, events_and_contexts, backfilled=False):
         if not events_and_contexts:
             return
 
@@ -118,8 +255,7 @@ class EventsStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, context, current_state=None, backfilled=False):
-
+    def _persist_event(self, event, context, current_state=None, backfilled=False):
         try:
             with self._stream_id_gen.get_next() as stream_ordering:
                 with self._state_groups_id_gen.get_next() as state_group_id:
@@ -136,9 +272,6 @@ class EventsStore(SQLBaseStore):
         except _RollbackButIsFineException:
             pass
 
-        max_persisted_id = yield self._stream_id_gen.get_current_token()
-        defer.returnValue((stream_ordering, max_persisted_id))
-
     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
                   get_prev_content=False, allow_rejected=False,