summary refs log tree commit diff
path: root/synapse/storage/persist_events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/persist_events.py')
-rw-r--r--synapse/storage/persist_events.py96
1 files changed, 83 insertions, 13 deletions
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 4d2d88d1f0..70e636b0ba 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -96,7 +96,9 @@ class _EventPeristenceQueue:
 
         Returns:
             defer.Deferred: a deferred which will resolve once the events are
-                persisted. Runs its callbacks *without* a logcontext.
+            persisted. Runs its callbacks *without* a logcontext. The result
+            is the same as that returned by the callback passed to
+            `handle_queue`.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
@@ -199,7 +201,7 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> RoomStreamToken:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """
         Write events to the database
         Args:
@@ -209,7 +211,11 @@ class EventsPersistenceStorage:
                 which might update the current state etc.
 
         Returns:
-            the stream ordering of the latest persisted event
+            List of events persisted, the current position room stream position.
+            The list of events persisted may not be the same as those passed in
+            if they were deduplicated due to an event already existing that
+            matched the transcation ID; the existing event is returned in such
+            a case.
         """
         partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
         for event, ctx in events_and_contexts:
@@ -225,19 +231,41 @@ class EventsPersistenceStorage:
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        await make_deferred_yieldable(
+        # Each deferred returns a map from event ID to existing event ID if the
+        # event was deduplicated. (The dict may also include other entries if
+        # the event was persisted in a batch with other events).
+        #
+        # Since we use `defer.gatherResults` we need to merge the returned list
+        # of dicts into one.
+        ret_vals = await make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
+        replaced_events = {}
+        for d in ret_vals:
+            replaced_events.update(d)
+
+        events = []
+        for event, _ in events_and_contexts:
+            existing_event_id = replaced_events.get(event.event_id)
+            if existing_event_id:
+                events.append(await self.main_store.get_event(existing_event_id))
+            else:
+                events.append(event)
 
-        return self.main_store.get_room_max_token()
+        return (
+            events,
+            self.main_store.get_room_max_token(),
+        )
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
+    ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
-            The stream ordering of `event`, and the stream ordering of the
-            latest persisted event
+            The event, stream ordering of `event`, and the stream ordering of the
+            latest persisted event. The returned event may not match the given
+            event if it was deduplicated due to an existing event matching the
+            transaction ID.
         """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)], backfilled=backfilled
@@ -245,19 +273,33 @@ class EventsPersistenceStorage:
 
         self._maybe_start_persisting(event.room_id)
 
-        await make_deferred_yieldable(deferred)
+        # The deferred returns a map from event ID to existing event ID if the
+        # event was deduplicated. (The dict may also include other entries if
+        # the event was persisted in a batch with other events.)
+        replaced_events = await make_deferred_yieldable(deferred)
+        replaced_event = replaced_events.get(event.event_id)
+        if replaced_event:
+            event = await self.main_store.get_event(replaced_event)
 
         event_stream_id = event.internal_metadata.stream_ordering
         # stream ordering should have been assigned by now
         assert event_stream_id
 
         pos = PersistedEventPosition(self._instance_name, event_stream_id)
-        return pos, self.main_store.get_room_max_token()
+        return event, pos, self.main_store.get_room_max_token()
 
     def _maybe_start_persisting(self, room_id: str):
+        """Pokes the `_event_persist_queue` to start handling new items in the
+        queue, if not already in progress.
+
+        Causes the deferreds returned by `add_to_queue` to resolve with: a
+        dictionary of event ID to event ID we didn't persist as we already had
+        another event persisted with the same TXN ID.
+        """
+
         async def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
-                await self._persist_events(
+                return await self._persist_events(
                     item.events_and_contexts, backfilled=item.backfilled
                 )
 
@@ -267,12 +309,38 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ):
+    ) -> Dict[str, str]:
         """Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
+
+        Returns:
+            A dictionary of event ID to event ID we didn't persist as we already
+            had another event persisted with the same TXN ID.
         """
+        replaced_events = {}  # type: Dict[str, str]
         if not events_and_contexts:
-            return
+            return replaced_events
+
+        # Check if any of the events have a transaction ID that has already been
+        # persisted, and if so we don't persist it again.
+        #
+        # We should have checked this a long time before we get here, but it's
+        # possible that different send event requests race in such a way that
+        # they both pass the earlier checks. Checking here isn't racey as we can
+        # have only one `_persist_events` per room being called at a time.
+        replaced_events = await self.main_store.get_already_persisted_events(
+            (event for event, _ in events_and_contexts)
+        )
+
+        if replaced_events:
+            events_and_contexts = [
+                (e, ctx)
+                for e, ctx in events_and_contexts
+                if e.event_id not in replaced_events
+            ]
+
+            if not events_and_contexts:
+                return replaced_events
 
         chunks = [
             events_and_contexts[x : x + 100]
@@ -441,6 +509,8 @@ class EventsPersistenceStorage:
 
             await self._handle_potentially_left_users(potentially_left_users)
 
+        return replaced_events
+
     async def _calculate_new_extremities(
         self,
         room_id: str,