summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
author Quentin Gliech <quenting@element.io> 2023-04-25 10:37:09 +0200
committer GitHub <noreply@github.com> 2023-04-25 09:37:09 +0100
commit 8b3a50299658a27175f55f1051e9470553c76d8e (patch)
tree 902f659655a95e010ffc82dbd7ad6f07ecba82bb /synapse/storage/databases/main
parent Finish type hints for federation client HTTP code. (#15465) (diff)
download synapse-8b3a50299658a27175f55f1051e9470553c76d8e.tar.xz
Experimental support for MSC3970: per-device transaction IDs (#15318)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r-- synapse/storage/databases/main/events.py 68
-rw-r--r-- synapse/storage/databases/main/events_worker.py 33
2 files changed, 83 insertions, 18 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9c1e506da6..c229de48c8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -127,6 +127,8 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+        self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
     @trace
     async def _persist_events_and_state_updates(
         self,
@@ -977,23 +979,43 @@ class PersistEventsStore:
     ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
-        to_insert = []
+        inserted_ts = self._clock.time_msec()
+        to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = []
+        to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = []
         for event, _ in events_and_contexts:
-            token_id = getattr(event.internal_metadata, "token_id", None)
             txn_id = getattr(event.internal_metadata, "txn_id", None)
-            if token_id and txn_id:
-                to_insert.append(
-                    (
-                        event.event_id,
-                        event.room_id,
-                        event.sender,
-                        token_id,
-                        txn_id,
-                        self._clock.time_msec(),
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            device_id = getattr(event.internal_metadata, "device_id", None)
+
+            if txn_id is not None:
+                if token_id is not None:
+                    to_insert_token_id.append(
+                        (
+                            event.event_id,
+                            event.room_id,
+                            event.sender,
+                            token_id,
+                            txn_id,
+                            inserted_ts,
+                        )
                     )
-                )
 
-        if to_insert:
+                if device_id is not None:
+                    to_insert_device_id.append(
+                        (
+                            event.event_id,
+                            event.room_id,
+                            event.sender,
+                            device_id,
+                            txn_id,
+                            inserted_ts,
+                        )
+                    )
+
+        # Pre-MSC3970, we rely on the access_token_id to scope the txn_id for events.
+        # Since this is an experimental flag, we still store the mapping even if the
+        # flag is disabled.
+        if to_insert_token_id:
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_txn_id",
@@ -1005,7 +1027,25 @@ class PersistEventsStore:
                     "txn_id",
                     "inserted_ts",
                 ),
-                values=to_insert,
+                values=to_insert_token_id,
+            )
+
+        # With MSC3970, we rely on the device_id instead to scope the txn_id for events.
+        # We're only inserting if MSC3970 is *enabled*, because else the pre-MSC3970
+        # behaviour would allow for a UNIQUE constraint violation on this table
+        if to_insert_device_id and self._msc3970_enabled:
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="event_txn_id_device_id",
+                keys=(
+                    "event_id",
+                    "room_id",
+                    "user_id",
+                    "device_id",
+                    "txn_id",
+                    "inserted_ts",
+                ),
+                values=to_insert_device_id,
             )
 
     async def update_current_state(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 0cf46626d2..0ff3fc7369 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2022,7 +2022,7 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
 
-    async def get_event_id_from_transaction_id(
+    async def get_event_id_from_transaction_id_and_token_id(
         self, room_id: str, user_id: str, token_id: int, txn_id: str
     ) -> Optional[str]:
         """Look up if we have already persisted an event for the transaction ID,
@@ -2038,7 +2038,26 @@ class EventsWorkerStore(SQLBaseStore):
             },
             retcol="event_id",
             allow_none=True,
-            desc="get_event_id_from_transaction_id",
+            desc="get_event_id_from_transaction_id_and_token_id",
+        )
+
+    async def get_event_id_from_transaction_id_and_device_id(
+        self, room_id: str, user_id: str, device_id: str, txn_id: str
+    ) -> Optional[str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning the event ID if so.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            table="event_txn_id_device_id",
+            keyvalues={
+                "room_id": room_id,
+                "user_id": user_id,
+                "device_id": device_id,
+                "txn_id": txn_id,
+            },
+            retcol="event_id",
+            allow_none=True,
+            desc="get_event_id_from_transaction_id_and_device_id",
         )
 
     async def get_already_persisted_events(
@@ -2068,7 +2087,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # Check if this is a duplicate of an event we've already
                 # persisted.
-                existing = await self.get_event_id_from_transaction_id(
+                existing = await self.get_event_id_from_transaction_id_and_token_id(
                     event.room_id, event.sender, token_id, txn_id
                 )
                 if existing:
@@ -2084,11 +2103,17 @@ class EventsWorkerStore(SQLBaseStore):
         """Cleans out transaction id mappings older than 24hrs."""
 
         def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
+            one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
             """
-            one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+            txn.execute(sql, (one_day_ago,))
+
+            sql = """
+                DELETE FROM event_txn_id_device_id
+                WHERE inserted_ts < ?
+            """
             txn.execute(sql, (one_day_ago,))
 
         return await self.db_pool.runInteraction(