summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-10-13 12:07:56 +0100
committerGitHub <noreply@github.com>2020-10-13 12:07:56 +0100
commitb2486f6656bec2307e62de19d2830994a42b879d (patch)
tree630ab6e5b341a53b5f6c1f7e966ab79673eb73b9
parentMerge branch 'master' into develop (diff)
downloadsynapse-b2486f6656bec2307e62de19d2830994a42b879d.tar.xz
Fix message duplication if something goes wrong after persisting the event (#8476)
Should fix #3365.
-rw-r--r--changelog.d/8476.bugfix1
-rw-r--r--synapse/handlers/federation.py9
-rw-r--r--synapse/handlers/message.py48
-rw-r--r--synapse/handlers/room_member.py13
-rw-r--r--synapse/replication/http/send_event.py16
-rw-r--r--synapse/storage/databases/main/events.py31
-rw-r--r--synapse/storage/databases/main/events_worker.py83
-rw-r--r--synapse/storage/databases/main/registration.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19txn_id.sql40
-rw-r--r--synapse/storage/persist_events.py96
-rw-r--r--tests/handlers/test_message.py157
-rw-r--r--tests/rest/client/test_third_party_rules.py2
-rw-r--r--tests/unittest.py11
13 files changed, 481 insertions, 32 deletions
diff --git a/changelog.d/8476.bugfix b/changelog.d/8476.bugfix
new file mode 100644
index 0000000000..993a269979
--- /dev/null
+++ b/changelog.d/8476.bugfix
@@ -0,0 +1 @@
+Fix message duplication if something goes wrong after persisting the event.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5ac2fc5656..455acd7669 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2966,17 +2966,20 @@ class FederationHandler(BaseHandler):
             return result["max_stream_id"]
         else:
             assert self.storage.persistence
-            max_stream_token = await self.storage.persistence.persist_events(
+
+            # Note that this returns the events that were persisted, which may not be
+            # the same as were passed in if some were deduplicated due to transaction IDs.
+            events, max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
             if self._ephemeral_messages_enabled:
-                for (event, context) in event_and_contexts:
+                for event in events:
                     # If there's an expiry timestamp on the event, schedule its expiry.
                     self._message_handler.maybe_schedule_expiry(event)
 
             if not backfilled:  # Never notify for backfilled events
-                for event, _ in event_and_contexts:
+                for event in events:
                     await self._notify_persisted_event(event, max_stream_token)
 
             return max_stream_token.stream
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ad0b7bd868..b0da938aa9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -689,7 +689,7 @@ class EventCreationHandler:
                 send this event.
 
         Returns:
-            The event, and its stream ordering (if state event deduplication happened,
+            The event, and its stream ordering (if deduplication happened,
             the previous, duplicate event).
 
         Raises:
@@ -712,6 +712,19 @@ class EventCreationHandler:
         # extremities to pile up, which in turn leads to state resolution
         # taking longer.
         with (await self.limiter.queue(event_dict["room_id"])):
+            if txn_id and requester.access_token_id:
+                existing_event_id = await self.store.get_event_id_from_transaction_id(
+                    event_dict["room_id"],
+                    requester.user.to_string(),
+                    requester.access_token_id,
+                    txn_id,
+                )
+                if existing_event_id:
+                    event = await self.store.get_event(existing_event_id)
+                    # we know it was persisted, so must have a stream ordering
+                    assert event.internal_metadata.stream_ordering
+                    return event, event.internal_metadata.stream_ordering
+
             event, context = await self.create_event(
                 requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
             )
@@ -913,10 +926,20 @@ class EventCreationHandler:
                     extra_users=extra_users,
                 )
                 stream_id = result["stream_id"]
-                event.internal_metadata.stream_ordering = stream_id
+                event_id = result["event_id"]
+                if event_id != event.event_id:
+                    # If we get a different event back then it means that its
+                    # been de-duplicated, so we replace the given event with the
+                    # one already persisted.
+                    event = await self.store.get_event(event_id)
+                else:
+                    # If we newly persisted the event then we need to update its
+                    # stream_ordering entry manually (as it was persisted on
+                    # another worker).
+                    event.internal_metadata.stream_ordering = stream_id
                 return event
 
-            stream_id = await self.persist_and_notify_client_event(
+            event = await self.persist_and_notify_client_event(
                 requester, event, context, ratelimit=ratelimit, extra_users=extra_users
             )
 
@@ -965,11 +988,16 @@ class EventCreationHandler:
         context: EventContext,
         ratelimit: bool = True,
         extra_users: List[UserID] = [],
-    ) -> int:
+    ) -> EventBase:
         """Called when we have fully built the event, have already
         calculated the push actions for the event, and checked auth.
 
         This should only be run on the instance in charge of persisting events.
+
+        Returns:
+            The persisted event. This may be different than the given event if
+            it was de-duplicated (e.g. because we had already persisted an
+            event with the same transaction ID.)
         """
         assert self.storage.persistence is not None
         assert self._events_shard_config.should_handle(
@@ -1137,9 +1165,13 @@ class EventCreationHandler:
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
-        event_pos, max_stream_token = await self.storage.persistence.persist_event(
-            event, context=context
-        )
+        # Note that this returns the event that was persisted, which may not be
+        # the same as we passed in if it was deduplicated due transaction IDs.
+        (
+            event,
+            event_pos,
+            max_stream_token,
+        ) = await self.storage.persistence.persist_event(event, context=context)
 
         if self._ephemeral_events_enabled:
             # If there's an expiry timestamp on the event, schedule its expiry.
@@ -1160,7 +1192,7 @@ class EventCreationHandler:
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)
 
-        return event_pos.stream
+        return event
 
     async def _bump_active_time(self, user: UserID) -> None:
         try:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index ffbc62ff44..0080eeaf8d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -171,6 +171,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if requester.is_guest:
             content["kind"] = "guest"
 
+        # Check if we already have an event with a matching transaction ID. (We
+        # do this check just before we persist an event as well, but may as well
+        # do it up front for efficiency.)
+        if txn_id and requester.access_token_id:
+            existing_event_id = await self.store.get_event_id_from_transaction_id(
+                room_id, requester.user.to_string(), requester.access_token_id, txn_id,
+            )
+            if existing_event_id:
+                event_pos = await self.store.get_position_for_event(existing_event_id)
+                return existing_event_id, event_pos.stream
+
         event, context = await self.event_creation_handler.create_event(
             requester,
             {
@@ -679,7 +690,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if is_blocked:
                 raise SynapseError(403, "This room has been blocked on this server")
 
-        await self.event_creation_handler.handle_new_client_event(
+        event = await self.event_creation_handler.handle_new_client_event(
             requester, event, context, extra_users=[target_user], ratelimit=ratelimit
         )
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 9a3a694d5d..fc129dbaa7 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -46,6 +46,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             "ratelimit": true,
             "extra_users": [],
         }
+
+        200 OK
+
+        { "stream_id": 12345, "event_id": "$abcdef..." }
+
+    The returned event ID may not match the sent event if it was deduplicated.
     """
 
     NAME = "send_event"
@@ -116,11 +122,17 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
         )
 
-        stream_id = await self.event_creation_handler.persist_and_notify_client_event(
+        event = await self.event_creation_handler.persist_and_notify_client_event(
             requester, event, context, ratelimit=ratelimit, extra_users=extra_users
         )
 
-        return 200, {"stream_id": stream_id}
+        return (
+            200,
+            {
+                "stream_id": event.internal_metadata.stream_ordering,
+                "event_id": event.event_id,
+            },
+        )
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b19c424ba9..fdb17745f6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -361,6 +361,8 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
+        self._persist_transaction_ids_txn(txn, events_and_contexts)
+
         # Insert into event_to_state_groups.
         self._store_event_state_mappings_txn(txn, events_and_contexts)
 
@@ -405,6 +407,35 @@ class PersistEventsStore:
         # room_memberships, where applicable.
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
+    def _persist_transaction_ids_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ):
+        """Persist the mapping from transaction IDs to event IDs (if defined).
+        """
+
+        to_insert = []
+        for event, _ in events_and_contexts:
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            txn_id = getattr(event.internal_metadata, "txn_id", None)
+            if token_id and txn_id:
+                to_insert.append(
+                    {
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "user_id": event.sender,
+                        "token_id": token_id,
+                        "txn_id": txn_id,
+                        "inserted_ts": self._clock.time_msec(),
+                    }
+                )
+
+        if to_insert:
+            self.db_pool.simple_insert_many_txn(
+                txn, table="event_txn_id", values=to_insert,
+            )
+
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 4e74fafe43..3ec4d1d9c2 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import itertools
 import logging
 import threading
@@ -137,6 +136,15 @@ class EventsWorkerStore(SQLBaseStore):
                     db_conn, "events", "stream_ordering", step=-1
                 )
 
+        if not hs.config.worker.worker_app:
+            # We periodically clean out old transaction ID mappings
+            self._clock.looping_call(
+                run_as_background_process,
+                5 * 60 * 1000,
+                "_cleanup_old_transaction_ids",
+                self._cleanup_old_transaction_ids,
+            )
+
         self._get_event_cache = Cache(
             "*getEvent*",
             keylen=3,
@@ -1308,3 +1316,76 @@ class EventsWorkerStore(SQLBaseStore):
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
+
+    async def get_event_id_from_transaction_id(
+        self, room_id: str, user_id: str, token_id: int, txn_id: str
+    ) -> Optional[str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning the event ID if so.
+        """
+        return await self.db_pool.simple_select_one_onecol(
+            table="event_txn_id",
+            keyvalues={
+                "room_id": room_id,
+                "user_id": user_id,
+                "token_id": token_id,
+                "txn_id": txn_id,
+            },
+            retcol="event_id",
+            allow_none=True,
+            desc="get_event_id_from_transaction_id",
+        )
+
+    async def get_already_persisted_events(
+        self, events: Iterable[EventBase]
+    ) -> Dict[str, str]:
+        """Look up if we have already persisted an event for the transaction ID,
+        returning a mapping from event ID in the given list to the event ID of
+        an existing event.
+
+        Also checks if there are duplicates in the given events, if there are
+        will map duplicates to the *first* event.
+        """
+
+        mapping = {}
+        txn_id_to_event = {}  # type: Dict[Tuple[str, int, str], str]
+
+        for event in events:
+            token_id = getattr(event.internal_metadata, "token_id", None)
+            txn_id = getattr(event.internal_metadata, "txn_id", None)
+
+            if token_id and txn_id:
+                # Check if this is a duplicate of an event in the given events.
+                existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
+                if existing:
+                    mapping[event.event_id] = existing
+                    continue
+
+                # Check if this is a duplicate of an event we've already
+                # persisted.
+                existing = await self.get_event_id_from_transaction_id(
+                    event.room_id, event.sender, token_id, txn_id
+                )
+                if existing:
+                    mapping[event.event_id] = existing
+                    txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
+                else:
+                    txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
+
+        return mapping
+
+    async def _cleanup_old_transaction_ids(self):
+        """Cleans out transaction id mappings older than 24hrs.
+        """
+
+        def _cleanup_old_transaction_ids_txn(txn):
+            sql = """
+                DELETE FROM event_txn_id
+                WHERE inserted_ts < ?
+            """
+            one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
+            txn.execute(sql, (one_day_ago,))
+
+        return await self.db_pool.runInteraction(
+            "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
+        )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 236d3cdbe3..9a003e30f9 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         token: str,
         device_id: Optional[str],
         valid_until_ms: Optional[int],
-    ) -> None:
+    ) -> int:
         """Adds an access token for the given user.
 
         Args:
@@ -1013,6 +1013,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             valid_until_ms: when the token is valid until. None for no expiry.
         Raises:
             StoreError if there was a problem adding this.
+        Returns:
+            The token ID
         """
         next_id = self._access_tokens_id_gen.get_next()
 
@@ -1028,6 +1030,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
+        return next_id
+
     def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn, "access_tokens", {"token": token}, "device_id"
diff --git a/synapse/storage/databases/main/schema/delta/58/19txn_id.sql b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
new file mode 100644
index 0000000000..b2454121a8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19txn_id.sql
@@ -0,0 +1,40 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A map of recent events persisted with transaction IDs. Used to deduplicate
+-- send event requests with the same transaction ID.
+--
+-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
+-- used to make the request.
+--
+-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
+-- events or access token we don't want to try and de-duplicate the event.
+CREATE TABLE IF NOT EXISTS event_txn_id (
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    user_id TEXT NOT NULL,
+    token_id BIGINT NOT NULL,
+    txn_id TEXT NOT NULL,
+    inserted_ts BIGINT NOT NULL,
+    FOREIGN KEY (event_id)
+        REFERENCES events (event_id) ON DELETE CASCADE,
+    FOREIGN KEY (token_id)
+        REFERENCES access_tokens (id) ON DELETE CASCADE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
+CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
+CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 4d2d88d1f0..70e636b0ba 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -96,7 +96,9 @@ class _EventPeristenceQueue:
 
         Returns:
             defer.Deferred: a deferred which will resolve once the events are
-                persisted. Runs its callbacks *without* a logcontext.
+            persisted. Runs its callbacks *without* a logcontext. The result
+            is the same as that returned by the callback passed to
+            `handle_queue`.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
         if queue:
@@ -199,7 +201,7 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> RoomStreamToken:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """
         Write events to the database
         Args:
@@ -209,7 +211,11 @@ class EventsPersistenceStorage:
                 which might update the current state etc.
 
         Returns:
-            the stream ordering of the latest persisted event
+            List of events persisted, the current position room stream position.
+            The list of events persisted may not be the same as those passed in
+            if they were deduplicated due to an event already existing that
+            matched the transcation ID; the existing event is returned in such
+            a case.
         """
         partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
         for event, ctx in events_and_contexts:
@@ -225,19 +231,41 @@ class EventsPersistenceStorage:
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        await make_deferred_yieldable(
+        # Each deferred returns a map from event ID to existing event ID if the
+        # event was deduplicated. (The dict may also include other entries if
+        # the event was persisted in a batch with other events).
+        #
+        # Since we use `defer.gatherResults` we need to merge the returned list
+        # of dicts into one.
+        ret_vals = await make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
+        replaced_events = {}
+        for d in ret_vals:
+            replaced_events.update(d)
+
+        events = []
+        for event, _ in events_and_contexts:
+            existing_event_id = replaced_events.get(event.event_id)
+            if existing_event_id:
+                events.append(await self.main_store.get_event(existing_event_id))
+            else:
+                events.append(event)
 
-        return self.main_store.get_room_max_token()
+        return (
+            events,
+            self.main_store.get_room_max_token(),
+        )
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
+    ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
-            The stream ordering of `event`, and the stream ordering of the
-            latest persisted event
+            The event, stream ordering of `event`, and the stream ordering of the
+            latest persisted event. The returned event may not match the given
+            event if it was deduplicated due to an existing event matching the
+            transaction ID.
         """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)], backfilled=backfilled
@@ -245,19 +273,33 @@ class EventsPersistenceStorage:
 
         self._maybe_start_persisting(event.room_id)
 
-        await make_deferred_yieldable(deferred)
+        # The deferred returns a map from event ID to existing event ID if the
+        # event was deduplicated. (The dict may also include other entries if
+        # the event was persisted in a batch with other events.)
+        replaced_events = await make_deferred_yieldable(deferred)
+        replaced_event = replaced_events.get(event.event_id)
+        if replaced_event:
+            event = await self.main_store.get_event(replaced_event)
 
         event_stream_id = event.internal_metadata.stream_ordering
         # stream ordering should have been assigned by now
         assert event_stream_id
 
         pos = PersistedEventPosition(self._instance_name, event_stream_id)
-        return pos, self.main_store.get_room_max_token()
+        return event, pos, self.main_store.get_room_max_token()
 
     def _maybe_start_persisting(self, room_id: str):
+        """Pokes the `_event_persist_queue` to start handling new items in the
+        queue, if not already in progress.
+
+        Causes the deferreds returned by `add_to_queue` to resolve with: a
+        dictionary of event ID to event ID we didn't persist as we already had
+        another event persisted with the same TXN ID.
+        """
+
         async def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
-                await self._persist_events(
+                return await self._persist_events(
                     item.events_and_contexts, backfilled=item.backfilled
                 )
 
@@ -267,12 +309,38 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ):
+    ) -> Dict[str, str]:
         """Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
+
+        Returns:
+            A dictionary of event ID to event ID we didn't persist as we already
+            had another event persisted with the same TXN ID.
         """
+        replaced_events = {}  # type: Dict[str, str]
         if not events_and_contexts:
-            return
+            return replaced_events
+
+        # Check if any of the events have a transaction ID that has already been
+        # persisted, and if so we don't persist it again.
+        #
+        # We should have checked this a long time before we get here, but it's
+        # possible that different send event requests race in such a way that
+        # they both pass the earlier checks. Checking here isn't racey as we can
+        # have only one `_persist_events` per room being called at a time.
+        replaced_events = await self.main_store.get_already_persisted_events(
+            (event for event, _ in events_and_contexts)
+        )
+
+        if replaced_events:
+            events_and_contexts = [
+                (e, ctx)
+                for e, ctx in events_and_contexts
+                if e.event_id not in replaced_events
+            ]
+
+            if not events_and_contexts:
+                return replaced_events
 
         chunks = [
             events_and_contexts[x : x + 100]
@@ -441,6 +509,8 @@ class EventsPersistenceStorage:
 
             await self._handle_potentially_left_users(potentially_left_users)
 
+        return replaced_events
+
     async def _calculate_new_extremities(
         self,
         room_id: str,
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
new file mode 100644
index 0000000000..64e28bc639
--- /dev/null
+++ b/tests/handlers/test_message.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Tuple
+
+from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import create_requester
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+class EventCreationTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.handler = self.hs.get_event_creation_handler()
+        self.persist_event_storage = self.hs.get_storage().persistence
+
+        self.user_id = self.register_user("tester", "foobar")
+        self.access_token = self.login("tester", "foobar")
+        self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+        self.info = self.get_success(
+            self.hs.get_datastore().get_user_by_access_token(self.access_token,)
+        )
+        self.token_id = self.info["token_id"]
+
+        self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+
+    def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+        """Create a new event with the given transaction ID. All events produced
+        by this method will be considered duplicates.
+        """
+
+        # We create a new event with a random body, as otherwise we'll produce
+        # *exactly* the same event with the same hash, and so same event ID.
+        return self.get_success(
+            self.handler.create_event(
+                self.requester,
+                {
+                    "type": EventTypes.Message,
+                    "room_id": self.room_id,
+                    "sender": self.requester.user.to_string(),
+                    "content": {"msgtype": "m.text", "body": random_string(5)},
+                },
+                token_id=self.token_id,
+                txn_id=txn_id,
+            )
+        )
+
+    def test_duplicated_txn_id(self):
+        """Test that attempting to handle/persist an event with a transaction ID
+        that has already been persisted correctly returns the old event and does
+        *not* produce duplicate messages.
+        """
+
+        txn_id = "something_suitably_random"
+
+        event1, context = self._create_duplicate_event(txn_id)
+
+        ret_event1 = self.get_success(
+            self.handler.handle_new_client_event(self.requester, event1, context)
+        )
+        stream_id1 = ret_event1.internal_metadata.stream_ordering
+
+        self.assertEqual(event1.event_id, ret_event1.event_id)
+
+        event2, context = self._create_duplicate_event(txn_id)
+
+        # We want to test that the deduplication at the persit event end works,
+        # so we want to make sure we test with different events.
+        self.assertNotEqual(event1.event_id, event2.event_id)
+
+        ret_event2 = self.get_success(
+            self.handler.handle_new_client_event(self.requester, event2, context)
+        )
+        stream_id2 = ret_event2.internal_metadata.stream_ordering
+
+        # Assert that the returned values match those from the initial event
+        # rather than the new one.
+        self.assertEqual(ret_event1.event_id, ret_event2.event_id)
+        self.assertEqual(stream_id1, stream_id2)
+
+        # Let's test that calling `persist_event` directly also does the right
+        # thing.
+        event3, context = self._create_duplicate_event(txn_id)
+        self.assertNotEqual(event1.event_id, event3.event_id)
+
+        ret_event3, event_pos3, _ = self.get_success(
+            self.persist_event_storage.persist_event(event3, context)
+        )
+
+        # Assert that the returned values match those from the initial event
+        # rather than the new one.
+        self.assertEqual(ret_event1.event_id, ret_event3.event_id)
+        self.assertEqual(stream_id1, event_pos3.stream)
+
+        # Let's test that calling `persist_events` directly also does the right
+        # thing.
+        event4, context = self._create_duplicate_event(txn_id)
+        self.assertNotEqual(event1.event_id, event3.event_id)
+
+        events, _ = self.get_success(
+            self.persist_event_storage.persist_events([(event3, context)])
+        )
+        ret_event4 = events[0]
+
+        # Assert that the returned values match those from the initial event
+        # rather than the new one.
+        self.assertEqual(ret_event1.event_id, ret_event4.event_id)
+
+    def test_duplicated_txn_id_one_call(self):
+        """Test that we correctly handle duplicates that we try and persist at
+        the same time.
+        """
+
+        txn_id = "something_else_suitably_random"
+
+        # Create two duplicate events to persist at the same time
+        event1, context1 = self._create_duplicate_event(txn_id)
+        event2, context2 = self._create_duplicate_event(txn_id)
+
+        # Ensure their event IDs are different to start with
+        self.assertNotEqual(event1.event_id, event2.event_id)
+
+        events, _ = self.get_success(
+            self.persist_event_storage.persist_events(
+                [(event1, context1), (event2, context2)]
+            )
+        )
+
+        # Check that we've deduplicated the events.
+        self.assertEqual(len(events), 2)
+        self.assertEqual(events[0].event_id, events[1].event_id)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index d03e121664..b737625e33 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -107,7 +107,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
 
         request, channel = self.make_request(
             "PUT",
-            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
+            "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
             {},
             access_token=self.tok,
         )
diff --git a/tests/unittest.py b/tests/unittest.py
index 5c87f6097e..6c1661c92c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -254,17 +254,24 @@ class HomeserverTestCase(TestCase):
         if hasattr(self, "user_id"):
             if self.hijack_auth:
 
+                # We need a valid token ID to satisfy foreign key constraints.
+                token_id = self.get_success(
+                    self.hs.get_datastore().add_access_token_to_user(
+                        self.helper.auth_user_id, "some_fake_token", None, None,
+                    )
+                )
+
                 async def get_user_by_access_token(token=None, allow_guest=False):
                     return {
                         "user": UserID.from_string(self.helper.auth_user_id),
-                        "token_id": 1,
+                        "token_id": token_id,
                         "is_guest": False,
                     }
 
                 async def get_user_by_req(request, allow_guest=False, rights="access"):
                     return create_requester(
                         UserID.from_string(self.helper.auth_user_id),
-                        1,
+                        token_id,
                         False,
                         False,
                         None,