summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/events.py31
-rw-r--r--synapse/storage/databases/main/events_worker.py83
-rw-r--r--synapse/storage/databases/main/registration.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19txn_id.sql40
-rw-r--r--synapse/storage/persist_events.py96
5 files changed, 241 insertions, 15 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b19c424ba9..fdb17745f6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -361,6 +361,8 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
+        self._persist_transaction_ids_txn(txn, events_and_contexts)
+
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
@@ -405,6 +407,35 @@ class PersistEventsStore:
         # room_memberships, where applicable.
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
+    def _persist_transaction_ids_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ):
+        """Persist the mapping from transaction IDs to event IDs (if defined).
+        """
+
+        to_insert = []
+        for event, _ in events_and_contexts:
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            txn_id = getattr(event.internal_metadata, "txn_id", None)
+            if token_id and txn_id:
+                to_insert.append(
+                    {
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "user_id": event.sender,
+                        "token_id": token_id,
+                        "txn_id": txn_id,
+                        "inserted_ts": self._clock.time_msec(),
+                    }
+                )
+
+        if to_insert:
+            self.db_pool.simple_insert_many_txn(
+                txn, table="event_txn_id", values=to_insert,
+            )
+
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4e74fafe43..3ec4d1d9c2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import itertools
 import logging
 import threading
@@ -137,6 +136,15 @@ class EventsWorkerStore(SQLBaseStore):
                     db_conn, "events", "stream_ordering", step=-1
                 )
 
+        if not hs.config.worker.worker_app:
+            # We periodically clean out old transaction ID mappings
+            self._clock.looping_call(
+                run_as_background_process,
+                5 * 60 * 1000,
+                "_cleanup_old_transaction_ids",
+                self._cleanup_old_transaction_ids,
+            )
+
         self._get_event_cache = Cache(
             "*getEvent*",
             keylen=3,
@@ -1308,3 +1316,76 @@ class EventsWorkerStore(SQLBaseStore):
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
+
+    async def get_event_id_from_transaction_id(
+        self, room_id: str, user_id: str, token_id: int, txn_id: str
+    ) -> Optional[str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning the event ID if so.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            table="event_txn_id",
+            keyvalues={
+                "room_id": room_id,
+                "user_id": user_id,
+                "token_id": token_id,
+                "txn_id": txn_id,
+            },
+            retcol="event_id",
+            allow_none=True,
+            desc="get_event_id_from_transaction_id",
+        )
+
+    async def get_already_persisted_events(
+        self, events: Iterable[EventBase]
+    ) -> Dict[str, str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning a mapping from event ID in the given list to the event ID of
+        an existing event.
+
+        Also checks if there are duplicates in the given events, if there are
+        will map duplicates to the *first* event.
+        """
+
+        mapping = {}
+        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+
+        for event in events:
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            txn_id = getattr(event.internal_metadata, "txn_id", None)
+
+            if token_id and txn_id:
+                # Check if this is a duplicate of an event in the given events.
+                existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
+                if existing:
+                    mapping[event.event_id] = existing
+                    continue
+
+                # Check if this is a duplicate of an event we've already
+                # persisted.
+                existing = await self.get_event_id_from_transaction_id(
+                    event.room_id, event.sender, token_id, txn_id
+                )
+                if existing:
+                    mapping[event.event_id] = existing
+                    txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
+                else:
+                    txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
+
+        return mapping
+
+    async def _cleanup_old_transaction_ids(self):
+        """Cleans out transaction id mappings older than 24hrs.
+        """
+
+        def _cleanup_old_transaction_ids_txn(txn):
+            sql = """
+                DELETE FROM event_txn_id
+                WHERE inserted_ts < ?
+            """
+            one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+            txn.execute(sql, (one_day_ago,))
+
+        return await self.db_pool.runInteraction(
+            "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+        )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 236d3cdbe3..9a003e30f9 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         token: str,
         device_id: Optional[str],
         valid_until_ms: Optional[int],
-    ) -> None:
+    ) -> int:
         """Adds an access token for the given user.
 
         Args:
@@ -1013,6 +1013,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             valid_until_ms: when the token is valid until. None for no expiry.
         Raises:
             StoreError if there was a problem adding this.
+        Returns:
+            The token ID
         """
         next_id = self._access_tokens_id_gen.get_next()
 
@@ -1028,6 +1030,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
+        return next_id
+
     def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn, "access_tokens", {"token": token}, "device_id"
diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
new file mode 100644
index 0000000000..b2454121a8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
@@ -0,0 +1,40 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A map of recent events persisted with transaction IDs. Used to deduplicate
+-- send event requests with the same transaction ID.
+--
+-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
+-- used to make the request.
+--
+-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
+-- events or access token we don't want to try and de-duplicate the event.
+CREATE TABLE IF NOT EXISTS event_txn_id (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    token_id BIGINT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (token_id)
+        REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 4d2d88d1f0..70e636b0ba 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -96,7 +96,9 @@ class _EventPeristenceQueue:
 
         Returns:
             defer.Deferred: a deferred which will resolve once the events are
-                persisted. Runs its callbacks *without* a logcontext.
+            persisted. Runs its callbacks *without* a logcontext. The result
+            is the same as that returned by the callback passed to
+            `handle_queue`.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
@@ -199,7 +201,7 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> RoomStreamToken:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """
         Write events to the database
         Args:
@@ -209,7 +211,11 @@ class EventsPersistenceStorage:
                 which might update the current state etc.
 
         Returns:
-            the stream ordering of the latest persisted event
+            List of events persisted, the current position room stream position.
+            The list of events persisted may not be the same as those passed in
+            if they were deduplicated due to an event already existing that
+            matched the transcation ID; the existing event is returned in such
+            a case.
         """
         partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
         for event, ctx in events_and_contexts:
@@ -225,19 +231,41 @@ class EventsPersistenceStorage:
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        await make_deferred_yieldable(
+        # Each deferred returns a map from event ID to existing event ID if the
+        # event was deduplicated. (The dict may also include other entries if
+        # the event was persisted in a batch with other events).
+        #
+        # Since we use `defer.gatherResults` we need to merge the returned list
+        # of dicts into one.
+        ret_vals = await make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
+        replaced_events = {}
+        for d in ret_vals:
+            replaced_events.update(d)
+
+        events = []
+        for event, _ in events_and_contexts:
+            existing_event_id = replaced_events.get(event.event_id)
+            if existing_event_id:
+                events.append(await self.main_store.get_event(existing_event_id))
+            else:
+                events.append(event)
 
-        return self.main_store.get_room_max_token()
+        return (
+            events,
+            self.main_store.get_room_max_token(),
+        )
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
+    ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
-            The stream ordering of `event`, and the stream ordering of the
-            latest persisted event
+            The event, stream ordering of `event`, and the stream ordering of the
+            latest persisted event. The returned event may not match the given
+            event if it was deduplicated due to an existing event matching the
+            transaction ID.
         """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)], backfilled=backfilled
@@ -245,19 +273,33 @@ class EventsPersistenceStorage:
 
         self._maybe_start_persisting(event.room_id)
 
-        await make_deferred_yieldable(deferred)
+        # The deferred returns a map from event ID to existing event ID if the
+        # event was deduplicated. (The dict may also include other entries if
+        # the event was persisted in a batch with other events.)
+        replaced_events = await make_deferred_yieldable(deferred)
+        replaced_event = replaced_events.get(event.event_id)
+        if replaced_event:
+            event = await self.main_store.get_event(replaced_event)
 
         event_stream_id = event.internal_metadata.stream_ordering
         # stream ordering should have been assigned by now
         assert event_stream_id
 
         pos = PersistedEventPosition(self._instance_name, event_stream_id)
-        return pos, self.main_store.get_room_max_token()
+        return event, pos, self.main_store.get_room_max_token()
 
     def _maybe_start_persisting(self, room_id: str):
+        """Pokes the `_event_persist_queue` to start handling new items in the
+        queue, if not already in progress.
+
+        Causes the deferreds returned by `add_to_queue` to resolve with: a
+        dictionary of event ID to event ID we didn't persist as we already had
+        another event persisted with the same TXN ID.
+        """
+
         async def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
-                await self._persist_events(
+                return await self._persist_events(
                     item.events_and_contexts, backfilled=item.backfilled
                 )
 
@@ -267,12 +309,38 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ):
+    ) -> Dict[str, str]:
         """Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
+
+        Returns:
+            A dictionary of event ID to event ID we didn't persist as we already
+            had another event persisted with the same TXN ID.
         """
+        replaced_events = {}  # type: Dict[str, str]
         if not events_and_contexts:
-            return
+            return replaced_events
+
+        # Check if any of the events have a transaction ID that has already been
+        # persisted, and if so we don't persist it again.
+        #
+        # We should have checked this a long time before we get here, but it's
+        # possible that different send event requests race in such a way that
+        # they both pass the earlier checks. Checking here isn't racey as we can
+        # have only one `_persist_events` per room being called at a time.
+        replaced_events = await self.main_store.get_already_persisted_events(
+            (event for event, _ in events_and_contexts)
+        )
+
+        if replaced_events:
+            events_and_contexts = [
+                (e, ctx)
+                for e, ctx in events_and_contexts
+                if e.event_id not in replaced_events
+            ]
+
+            if not events_and_contexts:
+                return replaced_events
 
         chunks = [
             events_and_contexts[x : x + 100]
@@ -441,6 +509,8 @@ class EventsPersistenceStorage:
 
             await self._handle_potentially_left_users(potentially_left_users)
 
+        return replaced_events
+
     async def _calculate_new_extremities(
         self,
         room_id: str,