summary refs log tree commit diff
path: root/synapse/storage/databases/main/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/events.py')
-rw-r--r--synapse/storage/databases/main/events.py68
1 files changed, 54 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9c1e506da6..c229de48c8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -127,6 +127,8 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+        self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
     @trace
     async def _persist_events_and_state_updates(
         self,
@@ -977,23 +979,43 @@ class PersistEventsStore:
     ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
-        to_insert = []
+        inserted_ts = self._clock.time_msec()
+        to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = []
+        to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = []
         for event, _ in events_and_contexts:
-            token_id = getattr(event.internal_metadata, "token_id", None)
             txn_id = getattr(event.internal_metadata, "txn_id", None)
-            if token_id and txn_id:
-                to_insert.append(
-                    (
-                        event.event_id,
-                        event.room_id,
-                        event.sender,
-                        token_id,
-                        txn_id,
-                        self._clock.time_msec(),
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            device_id = getattr(event.internal_metadata, "device_id", None)
+
+            if txn_id is not None:
+                if token_id is not None:
+                    to_insert_token_id.append(
+                        (
+                            event.event_id,
+                            event.room_id,
+                            event.sender,
+                            token_id,
+                            txn_id,
+                            inserted_ts,
+                        )
                     )
-                )
 
-        if to_insert:
+                if device_id is not None:
+                    to_insert_device_id.append(
+                        (
+                            event.event_id,
+                            event.room_id,
+                            event.sender,
+                            device_id,
+                            txn_id,
+                            inserted_ts,
+                        )
+                    )
+
+        # Pre-MSC3970, we rely on the access_token_id to scope the txn_id for events.
+        # Since this is an experimental flag, we still store the mapping even if the
+        # flag is disabled.
+        if to_insert_token_id:
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_txn_id",
@@ -1005,7 +1027,25 @@ class PersistEventsStore:
                     "txn_id",
                     "inserted_ts",
                 ),
-                values=to_insert,
+                values=to_insert_token_id,
+            )
+
+        # With MSC3970, we rely on the device_id instead to scope the txn_id for events.
+        # We're only inserting if MSC3970 is *enabled*, because else the pre-MSC3970
+        # behaviour would allow for a UNIQUE constraint violation on this table
+        if to_insert_device_id and self._msc3970_enabled:
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="event_txn_id_device_id",
+                keys=(
+                    "event_id",
+                    "room_id",
+                    "user_id",
+                    "device_id",
+                    "txn_id",
+                    "inserted_ts",
+                ),
+                values=to_insert_device_id,
             )
 
     async def update_current_state(