summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15318.feature1
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/events/__init__.py9
-rw-r--r--synapse/events/utils.py58
-rw-r--r--synapse/handlers/message.py38
-rw-r--r--synapse/handlers/room_member.py33
-rw-r--r--synapse/rest/client/transactions.py13
-rw-r--r--synapse/server.py4
-rw-r--r--synapse/storage/databases/main/events.py68
-rw-r--r--synapse/storage/databases/main/events_worker.py33
-rw-r--r--synapse/storage/schema/main/delta/74/05_events_txn_id_device_id.sql53
11 files changed, 265 insertions, 48 deletions
diff --git a/changelog.d/15318.feature b/changelog.d/15318.feature
new file mode 100644
index 0000000000..47bb2e17a7
--- /dev/null
+++ b/changelog.d/15318.feature
@@ -0,0 +1 @@
+Experimental support for MSC3970: Scope transaction IDs to devices.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 7687c80ea0..6599679731 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -191,3 +191,6 @@ class ExperimentalConfig(Config):
 
         # MSC2659: Application service ping endpoint
         self.msc2659_enabled = experimental.get("msc2659_enabled", False)
+
+        # MSC3970: Scope transaction IDs to devices
+        self.msc3970_enabled = experimental.get("msc3970_enabled", False)
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 4501518cf0..de7e5be42b 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -198,9 +198,16 @@ class _EventInternalMetadata:
     soft_failed: DictProperty[bool] = DictProperty("soft_failed")
     proactively_send: DictProperty[bool] = DictProperty("proactively_send")
     redacted: DictProperty[bool] = DictProperty("redacted")
+    historical: DictProperty[bool] = DictProperty("historical")
+
     txn_id: DictProperty[str] = DictProperty("txn_id")
+    """The transaction ID, if it was set when the event was created."""
+
     token_id: DictProperty[int] = DictProperty("token_id")
-    historical: DictProperty[bool] = DictProperty("historical")
+    """The access token ID of the user who sent this event, if any."""
+
+    device_id: DictProperty[str] = DictProperty("device_id")
+    """The device ID of the user who sent this event, if any."""
 
     # XXX: These are set by StreamWorkerStore._set_before_and_after.
     # I'm pretty sure that these are never persisted to the database, so shouldn't
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 1d5d7491cd..0802eb1963 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -339,6 +339,7 @@ def serialize_event(
     time_now_ms: int,
     *,
     config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
+    msc3970_enabled: bool = False,
 ) -> JsonDict:
     """Serialize event for clients
 
@@ -346,6 +347,8 @@ def serialize_event(
         e
         time_now_ms
         config: Event serialization config
+        msc3970_enabled: Whether MSC3970 is enabled. It changes whether we should
+            include the `transaction_id` in the event's `unsigned` section.
 
     Returns:
         The serialized event dictionary.
@@ -368,27 +371,43 @@ def serialize_event(
 
     if "redacted_because" in e.unsigned:
         d["unsigned"]["redacted_because"] = serialize_event(
-            e.unsigned["redacted_because"], time_now_ms, config=config
+            e.unsigned["redacted_because"],
+            time_now_ms,
+            config=config,
+            msc3970_enabled=msc3970_enabled,
         )
 
     # If we have a txn_id saved in the internal_metadata, we should include it in the
     # unsigned section of the event if it was sent by the same session as the one
     # requesting the event.
-    # There is a special case for guests, because they only have one access token
-    # without associated access_token_id, so we always include the txn_id for events
-    # they sent.
-    txn_id = getattr(e.internal_metadata, "txn_id", None)
+    txn_id: Optional[str] = getattr(e.internal_metadata, "txn_id", None)
     if txn_id is not None and config.requester is not None:
-        event_token_id = getattr(e.internal_metadata, "token_id", None)
-        if config.requester.user.to_string() == e.sender and (
-            (
-                event_token_id is not None
-                and config.requester.access_token_id is not None
-                and event_token_id == config.requester.access_token_id
+        # For the MSC3970 rules to be applied, we *need* to have the device ID in the
+        # event internal metadata. Since we were not recording them before, if it hasn't
+        # been recorded, we fallback to the old behaviour.
+        event_device_id: Optional[str] = getattr(e.internal_metadata, "device_id", None)
+        if msc3970_enabled and event_device_id is not None:
+            if event_device_id == config.requester.device_id:
+                d["unsigned"]["transaction_id"] = txn_id
+
+        else:
+            # The pre-MSC3970 behaviour is to only include the transaction ID if the
+            # event was sent from the same access token. For regular users, we can use
+            # the access token ID to determine this. For guests, we can't, but since
+            # each guest only has one access token, we can just check that the event was
+            # sent by the same user as the one requesting the event.
+            event_token_id: Optional[int] = getattr(
+                e.internal_metadata, "token_id", None
             )
-            or config.requester.is_guest
-        ):
-            d["unsigned"]["transaction_id"] = txn_id
+            if config.requester.user.to_string() == e.sender and (
+                (
+                    event_token_id is not None
+                    and config.requester.access_token_id is not None
+                    and event_token_id == config.requester.access_token_id
+                )
+                or config.requester.is_guest
+            ):
+                d["unsigned"]["transaction_id"] = txn_id
 
     # invite_room_state and knock_room_state are a list of stripped room state events
     # that are meant to provide metadata about a room to an invitee/knocker. They are
@@ -419,6 +438,9 @@ class EventClientSerializer:
     clients.
     """
 
+    def __init__(self, *, msc3970_enabled: bool = False):
+        self._msc3970_enabled = msc3970_enabled
+
     def serialize_event(
         self,
         event: Union[JsonDict, EventBase],
@@ -443,7 +465,9 @@ class EventClientSerializer:
         if not isinstance(event, EventBase):
             return event
 
-        serialized_event = serialize_event(event, time_now, config=config)
+        serialized_event = serialize_event(
+            event, time_now, config=config, msc3970_enabled=self._msc3970_enabled
+        )
 
         # Check if there are any bundled aggregations to include with the event.
         if bundle_aggregations:
@@ -501,7 +525,9 @@ class EventClientSerializer:
             # `sender` of the edit; however MSC3925 proposes extending it to the whole
             # of the edit, which is what we do here.
             serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
-                event_aggregations.replace, time_now, config=config
+                event_aggregations.replace,
+                time_now,
+                config=config,
             )
 
         # Include any threaded replies to this event.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 2e964ed37e..ac1932a7f9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -561,6 +561,8 @@ class EventCreationHandler:
                 expiry_ms=30 * 60 * 1000,
             )
 
+        self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
     async def create_event(
         self,
         requester: Requester,
@@ -701,9 +703,16 @@ class EventCreationHandler:
         if require_consent and not is_exempt:
             await self.assert_accepted_privacy_policy(requester)
 
+        # Save the access token ID, the device ID and the transaction ID in the event
+        # internal metadata. This is useful to determine if we should echo the
+        # transaction_id in events.
+        # See `synapse.events.utils.EventClientSerializer.serialize_event`
         if requester.access_token_id is not None:
             builder.internal_metadata.token_id = requester.access_token_id
 
+        if requester.device_id is not None:
+            builder.internal_metadata.device_id = requester.device_id
+
         if txn_id is not None:
             builder.internal_metadata.txn_id = txn_id
 
@@ -897,12 +906,31 @@ class EventCreationHandler:
         Returns:
             An event if one could be found, None otherwise.
         """
+
+        if self._msc3970_enabled and requester.device_id:
+            # When MSC3970 is enabled, we lookup for events sent by the same device first,
+            # and fallback to the old behaviour if none were found.
+            existing_event_id = (
+                await self.store.get_event_id_from_transaction_id_and_device_id(
+                    room_id,
+                    requester.user.to_string(),
+                    requester.device_id,
+                    txn_id,
+                )
+            )
+            if existing_event_id:
+                return await self.store.get_event(existing_event_id)
+
+        # Pre-MSC3970, we looked up for events that were sent by the same session by
+        # using the access token ID.
         if requester.access_token_id:
-            existing_event_id = await self.store.get_event_id_from_transaction_id(
-                room_id,
-                requester.user.to_string(),
-                requester.access_token_id,
-                txn_id,
+            existing_event_id = (
+                await self.store.get_event_id_from_transaction_id_and_token_id(
+                    room_id,
+                    requester.user.to_string(),
+                    requester.access_token_id,
+                    txn_id,
+                )
             )
             if existing_event_id:
                 return await self.store.get_event(existing_event_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ec317e6023..ed805d6ec8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -169,6 +169,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.request_ratelimiter = hs.get_request_ratelimiter()
         hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
 
+        self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
     def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
         """Notify the rate limiter that a room join has occurred.
 
@@ -399,13 +401,30 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # Check if we already have an event with a matching transaction ID. (We
         # do this check just before we persist an event as well, but may as well
         # do it up front for efficiency.)
-        if txn_id and requester.access_token_id:
-            existing_event_id = await self.store.get_event_id_from_transaction_id(
-                room_id,
-                requester.user.to_string(),
-                requester.access_token_id,
-                txn_id,
-            )
+        if txn_id:
+            existing_event_id = None
+            if self._msc3970_enabled and requester.device_id:
+                # When MSC3970 is enabled, we lookup for events sent by the same device
+                # first, and fallback to the old behaviour if none were found.
+                existing_event_id = (
+                    await self.store.get_event_id_from_transaction_id_and_device_id(
+                        room_id,
+                        requester.user.to_string(),
+                        requester.device_id,
+                        txn_id,
+                    )
+                )
+
+            if requester.access_token_id and not existing_event_id:
+                existing_event_id = (
+                    await self.store.get_event_id_from_transaction_id_and_token_id(
+                        room_id,
+                        requester.user.to_string(),
+                        requester.access_token_id,
+                        txn_id,
+                    )
+                )
+
             if existing_event_id:
                 event_pos = await self.store.get_position_for_event(existing_event_id)
                 return existing_event_id, event_pos.stream
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index f2aaab6227..0d8a63d8be 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -50,6 +50,8 @@ class HttpTransactionCache:
         # for at *LEAST* 30 mins, and at *MOST* 60 mins.
         self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
 
+        self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
     def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
         """A helper function which returns a transaction key that can be used
         with TransactionCache for idempotent requests.
@@ -58,6 +60,7 @@ class HttpTransactionCache:
         requests to the same endpoint. The key is formed from the HTTP request
         path and attributes from the requester: the access_token_id for regular users,
         the user ID for guest users, and the appservice ID for appservice users.
+        With MSC3970, for regular users, the key is based on the user ID and device ID.
 
         Args:
             request: The incoming request.
@@ -67,11 +70,21 @@ class HttpTransactionCache:
         """
         assert request.path is not None
         path: str = request.path.decode("utf8")
+
         if requester.is_guest:
             assert requester.user is not None, "Guest requester must have a user ID set"
             return (path, "guest", requester.user)
+
         elif requester.app_service is not None:
             return (path, "appservice", requester.app_service.id)
+
+        # With MSC3970, we use the user ID and device ID as the transaction key
+        elif self._msc3970_enabled:
+            assert requester.user, "Requester must have a user"
+            assert requester.device_id, "Requester must have a device_id"
+            return (path, "user", requester.user, requester.device_id)
+
+        # Otherwise, the pre-MSC3970 behaviour is to use the access token ID
         else:
             assert (
                 requester.access_token_id is not None
diff --git a/synapse/server.py b/synapse/server.py
index 559724594b..08ad97b952 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -762,7 +762,9 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_event_client_serializer(self) -> EventClientSerializer:
-        return EventClientSerializer()
+        return EventClientSerializer(
+            msc3970_enabled=self.config.experimental.msc3970_enabled
+        )
 
     @cache_in_self
     def get_password_policy_handler(self) -> PasswordPolicyHandler:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9c1e506da6..c229de48c8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -127,6 +127,8 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+        self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
     @trace
     async def _persist_events_and_state_updates(
         self,
@@ -977,23 +979,43 @@ class PersistEventsStore:
     ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
 
-        to_insert = []
+        inserted_ts = self._clock.time_msec()
+        to_insert_token_id: List[Tuple[str, str, str, int, str, int]] = []
+        to_insert_device_id: List[Tuple[str, str, str, str, str, int]] = []
         for event, _ in events_and_contexts:
-            token_id = getattr(event.internal_metadata, "token_id", None)
             txn_id = getattr(event.internal_metadata, "txn_id", None)
-            if token_id and txn_id:
-                to_insert.append(
-                    (
-                        event.event_id,
-                        event.room_id,
-                        event.sender,
-                        token_id,
-                        txn_id,
-                        self._clock.time_msec(),
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            device_id = getattr(event.internal_metadata, "device_id", None)
+
+            if txn_id is not None:
+                if token_id is not None:
+                    to_insert_token_id.append(
+                        (
+                            event.event_id,
+                            event.room_id,
+                            event.sender,
+                            token_id,
+                            txn_id,
+                            inserted_ts,
+                        )
                     )
-                )
 
-        if to_insert:
+                if device_id is not None:
+                    to_insert_device_id.append(
+                        (
+                            event.event_id,
+                            event.room_id,
+                            event.sender,
+                            device_id,
+                            txn_id,
+                            inserted_ts,
+                        )
+                    )
+
+        # Pre-MSC3970, we rely on the access_token_id to scope the txn_id for events.
+        # Since this is an experimental flag, we still store the mapping even if the
+        # flag is disabled.
+        if to_insert_token_id:
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="event_txn_id",
@@ -1005,7 +1027,25 @@ class PersistEventsStore:
                     "txn_id",
                     "inserted_ts",
                 ),
-                values=to_insert,
+                values=to_insert_token_id,
+            )
+
+        # With MSC3970, we rely on the device_id instead to scope the txn_id for events.
+        # We're only inserting if MSC3970 is *enabled*, because else the pre-MSC3970
+        # behaviour would allow for a UNIQUE constraint violation on this table
+        if to_insert_device_id and self._msc3970_enabled:
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="event_txn_id_device_id",
+                keys=(
+                    "event_id",
+                    "room_id",
+                    "user_id",
+                    "device_id",
+                    "txn_id",
+                    "inserted_ts",
+                ),
+                values=to_insert_device_id,
             )
 
     async def update_current_state(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 0cf46626d2..0ff3fc7369 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2022,7 +2022,7 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
 
-    async def get_event_id_from_transaction_id(
+    async def get_event_id_from_transaction_id_and_token_id(
         self, room_id: str, user_id: str, token_id: int, txn_id: str
     ) -> Optional[str]:
         """Look up if we have already persisted an event for the transaction ID,
@@ -2038,7 +2038,26 @@ class EventsWorkerStore(SQLBaseStore):
             },
             retcol="event_id",
             allow_none=True,
-            desc="get_event_id_from_transaction_id",
+            desc="get_event_id_from_transaction_id_and_token_id",
+        )
+
+    async def get_event_id_from_transaction_id_and_device_id(
+        self, room_id: str, user_id: str, device_id: str, txn_id: str
+    ) -> Optional[str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning the event ID if so.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            table="event_txn_id_device_id",
+            keyvalues={
+                "room_id": room_id,
+                "user_id": user_id,
+                "device_id": device_id,
+                "txn_id": txn_id,
+            },
+            retcol="event_id",
+            allow_none=True,
+            desc="get_event_id_from_transaction_id_and_device_id",
         )
 
     async def get_already_persisted_events(
@@ -2068,7 +2087,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # Check if this is a duplicate of an event we've already
                 # persisted.
-                existing = await self.get_event_id_from_transaction_id(
+                existing = await self.get_event_id_from_transaction_id_and_token_id(
                     event.room_id, event.sender, token_id, txn_id
                 )
                 if existing:
@@ -2084,11 +2103,17 @@ class EventsWorkerStore(SQLBaseStore):
         """Cleans out transaction id mappings older than 24hrs."""
 
         def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
+            one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
             """
-            one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+            txn.execute(sql, (one_day_ago,))
+
+            sql = """
+                DELETE FROM event_txn_id_device_id
+                WHERE inserted_ts < ?
+            """
             txn.execute(sql, (one_day_ago,))
 
         return await self.db_pool.runInteraction(
diff --git a/synapse/storage/schema/main/delta/74/05_events_txn_id_device_id.sql b/synapse/storage/schema/main/delta/74/05_events_txn_id_device_id.sql
new file mode 100644
index 0000000000..517a821a56
--- /dev/null
+++ b/synapse/storage/schema/main/delta/74/05_events_txn_id_device_id.sql
@@ -0,0 +1,53 @@
+/* Copyright 2023 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- For MSC3970, in addition to the (room_id, user_id, token_id, txn_id) -> event_id mapping for each local event,
+-- we also store the (room_id, user_id, device_id, txn_id) -> event_id mapping.
+--
+-- This adds a new event_txn_id_device_id table.
+
+-- A map of recent events persisted with transaction IDs. Used to deduplicate
+-- send event requests with the same transaction ID.
+--
+-- Note: with MSC3970, transaction IDs are scoped to the 
+-- room ID/user ID/device ID that was used to make the request.
+--
+-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
+-- event or device we don't want to try and de-duplicate the event.
+CREATE TABLE IF NOT EXISTS event_txn_id_device_id (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    device_id TEXT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (user_id, device_id)
+        REFERENCES devices (user_id, device_id) ON DELETE CASCADE
+);
+
+-- This ensures that there is only one mapping per event_id.
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_device_id_event_id
+    ON event_txn_id_device_id(event_id);
+
+-- This ensures that there is only one mapping per (room_id, user_id, device_id, txn_id) tuple.
+-- Events are usually looked up using this index.
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_device_id_txn_id 
+    ON event_txn_id_device_id(room_id, user_id, device_id, txn_id);
+
+-- This table is cleaned up regularly, removing the oldest entries, hence this index.
+CREATE INDEX IF NOT EXISTS event_txn_id_device_id_ts
+    ON event_txn_id_device_id(inserted_ts);