summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--.buildkite/worker-blacklist4
-rw-r--r--changelog.d/5788.bugfix1
-rw-r--r--changelog.d/5826.misc1
-rw-r--r--changelog.d/5839.bugfix1
-rw-r--r--changelog.d/5843.misc1
-rw-r--r--contrib/purge_api/purge_remote_media.sh2
-rw-r--r--synapse/storage/events.py270
-rw-r--r--synapse/storage/events_worker.py213
-rw-r--r--tests/storage/test_redaction.py70
9 files changed, 329 insertions, 234 deletions
diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist
index 8ed8eef1a3..cda5c84e94 100644
--- a/.buildkite/worker-blacklist
+++ b/.buildkite/worker-blacklist
@@ -3,10 +3,6 @@
 
 Message history can be paginated
 
-m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users
-
-m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users
-
 Can re-join room if re-invited
 
 /upgrade creates a new room
diff --git a/changelog.d/5788.bugfix b/changelog.d/5788.bugfix
new file mode 100644
index 0000000000..5632f3cb99
--- /dev/null
+++ b/changelog.d/5788.bugfix
@@ -0,0 +1 @@
+Correctly handle redactions of redactions.
diff --git a/changelog.d/5826.misc b/changelog.d/5826.misc
new file mode 100644
index 0000000000..9abed11bbe
--- /dev/null
+++ b/changelog.d/5826.misc
@@ -0,0 +1 @@
+Reduce global pauses in the events stream caused by expensive state resolution during persistence.
diff --git a/changelog.d/5839.bugfix b/changelog.d/5839.bugfix
new file mode 100644
index 0000000000..5775bfa653
--- /dev/null
+++ b/changelog.d/5839.bugfix
@@ -0,0 +1 @@
+The purge_remote_media.sh script was fixed.
diff --git a/changelog.d/5843.misc b/changelog.d/5843.misc
new file mode 100644
index 0000000000..e7e7d572b7
--- /dev/null
+++ b/changelog.d/5843.misc
@@ -0,0 +1 @@
+Whitelist history visbility sytests in worker mode tests.
diff --git a/contrib/purge_api/purge_remote_media.sh b/contrib/purge_api/purge_remote_media.sh
index 99c07c663d..77220d3bd5 100644
--- a/contrib/purge_api/purge_remote_media.sh
+++ b/contrib/purge_api/purge_remote_media.sh
@@ -51,4 +51,4 @@ TOKEN=$(sql "SELECT token FROM access_tokens WHERE user_id='$ADMIN' ORDER BY id
 # finally start pruning media:
 ###############################################################################
 set -x # for debugging the generated string
-curl --header "Authorization: Bearer $TOKEN" -v POST "$API_URL/admin/purge_media_cache/?before_ts=$UNIX_TIMESTAMP"
+curl --header "Authorization: Bearer $TOKEN" -X POST "$API_URL/admin/purge_media_cache/?before_ts=$UNIX_TIMESTAMP"
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 88c0180116..ac876287fc 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -364,147 +364,161 @@ class EventsStore(
         if not events_and_contexts:
             return
 
-        if backfilled:
-            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
-                len(events_and_contexts)
-            )
-        else:
-            stream_ordering_manager = self._stream_id_gen.get_next_mult(
-                len(events_and_contexts)
-            )
-
-        with stream_ordering_manager as stream_orderings:
-            for (event, context), stream in zip(events_and_contexts, stream_orderings):
-                event.internal_metadata.stream_ordering = stream
-
-            chunks = [
-                events_and_contexts[x : x + 100]
-                for x in range(0, len(events_and_contexts), 100)
-            ]
-
-            for chunk in chunks:
-                # We can't easily parallelize these since different chunks
-                # might contain the same event. :(
+        chunks = [
+            events_and_contexts[x : x + 100]
+            for x in range(0, len(events_and_contexts), 100)
+        ]
 
-                # NB: Assumes that we are only persisting events for one room
-                # at a time.
+        for chunk in chunks:
+            # We can't easily parallelize these since different chunks
+            # might contain the same event. :(
 
-                # map room_id->list[event_ids] giving the new forward
-                # extremities in each room
-                new_forward_extremeties = {}
+            # NB: Assumes that we are only persisting events for one room
+            # at a time.
 
-                # map room_id->(type,state_key)->event_id tracking the full
-                # state in each room after adding these events.
-                # This is simply used to prefill the get_current_state_ids
-                # cache
-                current_state_for_room = {}
+            # map room_id->list[event_ids] giving the new forward
+            # extremities in each room
+            new_forward_extremeties = {}
 
-                # map room_id->(to_delete, to_insert) where to_delete is a list
-                # of type/state keys to remove from current state, and to_insert
-                # is a map (type,key)->event_id giving the state delta in each
-                # room
-                state_delta_for_room = {}
+            # map room_id->(type,state_key)->event_id tracking the full
+            # state in each room after adding these events.
+            # This is simply used to prefill the get_current_state_ids
+            # cache
+            current_state_for_room = {}
 
-                if not backfilled:
-                    with Measure(self._clock, "_calculate_state_and_extrem"):
-                        # Work out the new "current state" for each room.
-                        # We do this by working out what the new extremities are and then
-                        # calculating the state from that.
-                        events_by_room = {}
-                        for event, context in chunk:
-                            events_by_room.setdefault(event.room_id, []).append(
-                                (event, context)
-                            )
+            # map room_id->(to_delete, to_insert) where to_delete is a list
+            # of type/state keys to remove from current state, and to_insert
+            # is a map (type,key)->event_id giving the state delta in each
+            # room
+            state_delta_for_room = {}
 
-                        for room_id, ev_ctx_rm in iteritems(events_by_room):
-                            latest_event_ids = yield self.get_latest_event_ids_in_room(
-                                room_id
-                            )
-                            new_latest_event_ids = yield self._calculate_new_extremities(
-                                room_id, ev_ctx_rm, latest_event_ids
+            if not backfilled:
+                with Measure(self._clock, "_calculate_state_and_extrem"):
+                    # Work out the new "current state" for each room.
+                    # We do this by working out what the new extremities are and then
+                    # calculating the state from that.
+                    events_by_room = {}
+                    for event, context in chunk:
+                        events_by_room.setdefault(event.room_id, []).append(
+                            (event, context)
+                        )
+
+                    for room_id, ev_ctx_rm in iteritems(events_by_room):
+                        latest_event_ids = yield self.get_latest_event_ids_in_room(
+                            room_id
+                        )
+                        new_latest_event_ids = yield self._calculate_new_extremities(
+                            room_id, ev_ctx_rm, latest_event_ids
+                        )
+
+                        latest_event_ids = set(latest_event_ids)
+                        if new_latest_event_ids == latest_event_ids:
+                            # No change in extremities, so no change in state
+                            continue
+
+                        # there should always be at least one forward extremity.
+                        # (except during the initial persistence of the send_join
+                        # results, in which case there will be no existing
+                        # extremities, so we'll `continue` above and skip this bit.)
+                        assert new_latest_event_ids, "No forward extremities left!"
+
+                        new_forward_extremeties[room_id] = new_latest_event_ids
+
+                        len_1 = (
+                            len(latest_event_ids) == 1
+                            and len(new_latest_event_ids) == 1
+                        )
+                        if len_1:
+                            all_single_prev_not_state = all(
+                                len(event.prev_event_ids()) == 1
+                                and not event.is_state()
+                                for event, ctx in ev_ctx_rm
                             )
-
-                            latest_event_ids = set(latest_event_ids)
-                            if new_latest_event_ids == latest_event_ids:
-                                # No change in extremities, so no change in state
+                            # Don't bother calculating state if they're just
+                            # a long chain of single ancestor non-state events.
+                            if all_single_prev_not_state:
                                 continue
 
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremeties[room_id] = new_latest_event_ids
-
-                            len_1 = (
-                                len(latest_event_ids) == 1
-                                and len(new_latest_event_ids) == 1
+                        state_delta_counter.inc()
+                        if len(new_latest_event_ids) == 1:
+                            state_delta_single_event_counter.inc()
+
+                            # This is a fairly handwavey check to see if we could
+                            # have guessed what the delta would have been when
+                            # processing one of these events.
+                            # What we're interested in is if the latest extremities
+                            # were the same when we created the event as they are
+                            # now. When this server creates a new event (as opposed
+                            # to receiving it over federation) it will use the
+                            # forward extremities as the prev_events, so we can
+                            # guess this by looking at the prev_events and checking
+                            # if they match the current forward extremities.
+                            for ev, _ in ev_ctx_rm:
+                                prev_event_ids = set(ev.prev_event_ids())
+                                if latest_event_ids == prev_event_ids:
+                                    state_delta_reuse_delta_counter.inc()
+                                    break
+
+                        logger.info("Calculating state delta for room %s", room_id)
+                        with Measure(
+                            self._clock, "persist_events.get_new_state_after_events"
+                        ):
+                            res = yield self._get_new_state_after_events(
+                                room_id,
+                                ev_ctx_rm,
+                                latest_event_ids,
+                                new_latest_event_ids,
                             )
-                            if len_1:
-                                all_single_prev_not_state = all(
-                                    len(event.prev_event_ids()) == 1
-                                    and not event.is_state()
-                                    for event, ctx in ev_ctx_rm
-                                )
-                                # Don't bother calculating state if they're just
-                                # a long chain of single ancestor non-state events.
-                                if all_single_prev_not_state:
-                                    continue
-
-                            state_delta_counter.inc()
-                            if len(new_latest_event_ids) == 1:
-                                state_delta_single_event_counter.inc()
-
-                                # This is a fairly handwavey check to see if we could
-                                # have guessed what the delta would have been when
-                                # processing one of these events.
-                                # What we're interested in is if the latest extremities
-                                # were the same when we created the event as they are
-                                # now. When this server creates a new event (as opposed
-                                # to receiving it over federation) it will use the
-                                # forward extremities as the prev_events, so we can
-                                # guess this by looking at the prev_events and checking
-                                # if they match the current forward extremities.
-                                for ev, _ in ev_ctx_rm:
-                                    prev_event_ids = set(ev.prev_event_ids())
-                                    if latest_event_ids == prev_event_ids:
-                                        state_delta_reuse_delta_counter.inc()
-                                        break
-
-                            logger.info("Calculating state delta for room %s", room_id)
+                            current_state, delta_ids = res
+
+                        # If either are not None then there has been a change,
+                        # and we need to work out the delta (or use that
+                        # given)
+                        if delta_ids is not None:
+                            # If there is a delta we know that we've
+                            # only added or replaced state, never
+                            # removed keys entirely.
+                            state_delta_for_room[room_id] = ([], delta_ids)
+                        elif current_state is not None:
                             with Measure(
-                                self._clock, "persist_events.get_new_state_after_events"
+                                self._clock, "persist_events.calculate_state_delta"
                             ):
-                                res = yield self._get_new_state_after_events(
-                                    room_id,
-                                    ev_ctx_rm,
-                                    latest_event_ids,
-                                    new_latest_event_ids,
+                                delta = yield self._calculate_state_delta(
+                                    room_id, current_state
                                 )
-                                current_state, delta_ids = res
-
-                            # If either are not None then there has been a change,
-                            # and we need to work out the delta (or use that
-                            # given)
-                            if delta_ids is not None:
-                                # If there is a delta we know that we've
-                                # only added or replaced state, never
-                                # removed keys entirely.
-                                state_delta_for_room[room_id] = ([], delta_ids)
-                            elif current_state is not None:
-                                with Measure(
-                                    self._clock, "persist_events.calculate_state_delta"
-                                ):
-                                    delta = yield self._calculate_state_delta(
-                                        room_id, current_state
-                                    )
-                                state_delta_for_room[room_id] = delta
-
-                            # If we have the current_state then lets prefill
-                            # the cache with it.
-                            if current_state is not None:
-                                current_state_for_room[room_id] = current_state
+                            state_delta_for_room[room_id] = delta
+
+                        # If we have the current_state then lets prefill
+                        # the cache with it.
+                        if current_state is not None:
+                            current_state_for_room[room_id] = current_state
+
+            # We want to calculate the stream orderings as late as possible, as
+            # we only notify after all events with a lesser stream ordering have
+            # been persisted. I.e. if we spend 10s inside the with block then
+            # that will delay all subsequent events from being notified about.
+            # Hence why we do it down here rather than wrapping the entire
+            # function.
+            #
+            # Its safe to do this after calculating the state deltas etc as we
+            # only need to protect the *persistence* of the events. This is to
+            # ensure that queries of the form "fetch events since X" don't
+            # return events and stream positions after events that are still in
+            # flight, as otherwise subsequent requests "fetch event since Y"
+            # will not return those events.
+            #
+            # Note: Multiple instances of this function cannot be in flight at
+            # the same time for the same room.
+            if backfilled:
+                stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                    len(chunk)
+                )
+            else:
+                stream_ordering_manager = self._stream_id_gen.get_next_mult(len(chunk))
+
+            with stream_ordering_manager as stream_orderings:
+                for (event, context), stream in zip(chunk, stream_orderings):
+                    event.internal_metadata.stream_ordering = stream
 
                 yield self.runInteraction(
                     "persist_events",
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 79680ee856..c6fa7f82fd 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -29,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions
 from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.events.utils import prune_event
-from synapse.logging.context import (
-    LoggingContext,
-    PreserveLoggingContext,
-    make_deferred_yieldable,
-    run_in_background,
-)
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
@@ -342,13 +337,12 @@ class EventsWorkerStore(SQLBaseStore):
             log_ctx = LoggingContext.current_context()
             log_ctx.record_event_fetch(len(missing_events_ids))
 
-            # Note that _enqueue_events is also responsible for turning db rows
+            # Note that _get_events_from_db is also responsible for turning db rows
             # into FrozenEvents (via _get_event_from_row), which involves seeing if
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            # _enqueue_events is a bit of a rubbish name but naming is hard.
-            missing_events = yield self._enqueue_events(
+            missing_events = yield self._get_events_from_db(
                 missing_events_ids, allow_rejected=allow_rejected
             )
 
@@ -421,28 +415,28 @@ class EventsWorkerStore(SQLBaseStore):
                 The fetch requests. Each entry consists of a list of event
                 ids to be fetched, and a deferred to be completed once the
                 events have been fetched.
+
+                The deferreds are callbacked with a dictionary mapping from event id
+                to event row. Note that it may well contain additional events that
+                were not part of this request.
         """
         with Measure(self._clock, "_fetch_event_list"):
             try:
-                event_id_lists = list(zip(*event_list))[0]
-                event_ids = [item for sublist in event_id_lists for item in sublist]
+                events_to_fetch = set(
+                    event_id for events, _ in event_list for event_id in events
+                )
 
                 row_dict = self._new_transaction(
-                    conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
+                    conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
                 )
 
                 # We only want to resolve deferreds from the main thread
-                def fire(lst, res):
-                    for ids, d in lst:
-                        if not d.called:
-                            try:
-                                with PreserveLoggingContext():
-                                    d.callback([res[i] for i in ids if i in res])
-                            except Exception:
-                                logger.exception("Failed to callback")
+                def fire():
+                    for _, d in event_list:
+                        d.callback(row_dict)
 
                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
+                    self.hs.get_reactor().callFromThread(fire)
             except Exception as e:
                 logger.exception("do_fetch")
 
@@ -457,13 +451,98 @@ class EventsWorkerStore(SQLBaseStore):
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
     @defer.inlineCallbacks
-    def _enqueue_events(self, events, allow_rejected=False):
+    def _get_events_from_db(self, event_ids, allow_rejected=False):
+        """Fetch a bunch of events from the database.
+
+        Returned events will be added to the cache for future lookups.
+
+        Args:
+            event_ids (Iterable[str]): The event_ids of the events to fetch
+            allow_rejected (bool): Whether to include rejected events
+
+        Returns:
+            Deferred[Dict[str, _EventCacheEntry]]:
+                map from event id to result. May return extra events which
+                weren't asked for.
+        """
+        fetched_events = {}
+        events_to_fetch = event_ids
+
+        while events_to_fetch:
+            row_map = yield self._enqueue_events(events_to_fetch)
+
+            # we need to recursively fetch any redactions of those events
+            redaction_ids = set()
+            for event_id in events_to_fetch:
+                row = row_map.get(event_id)
+                fetched_events[event_id] = row
+                if row:
+                    redaction_ids.update(row["redactions"])
+
+            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            if events_to_fetch:
+                logger.debug("Also fetching redaction events %s", events_to_fetch)
+
+        # build a map from event_id to EventBase
+        event_map = {}
+        for event_id, row in fetched_events.items():
+            if not row:
+                continue
+            assert row["event_id"] == event_id
+
+            rejected_reason = row["rejected_reason"]
+
+            if not allow_rejected and rejected_reason:
+                continue
+
+            d = json.loads(row["json"])
+            internal_metadata = json.loads(row["internal_metadata"])
+
+            format_version = row["format_version"]
+            if format_version is None:
+                # This means that we stored the event before we had the concept
+                # of a event format version, so it must be a V1 event.
+                format_version = EventFormatVersions.V1
+
+            original_ev = event_type_from_format_version(format_version)(
+                event_dict=d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
+            )
+
+            event_map[event_id] = original_ev
+
+        # finally, we can decide whether each one nededs redacting, and build
+        # the cache entries.
+        result_map = {}
+        for event_id, original_ev in event_map.items():
+            redactions = fetched_events[event_id]["redactions"]
+            redacted_event = self._maybe_redact_event_row(
+                original_ev, redactions, event_map
+            )
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev, redacted_event=redacted_event
+            )
+
+            self._get_event_cache.prefill((event_id,), cache_entry)
+            result_map[event_id] = cache_entry
+
+        return result_map
+
+    @defer.inlineCallbacks
+    def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
+
+        Args:
+            events (Iterable[str]): events to be fetched.
+
+        Returns:
+            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+                May contain events that weren't requested.
         """
-        if not events:
-            return {}
 
         events_d = defer.Deferred()
         with self._event_fetch_lock:
@@ -482,32 +561,12 @@ class EventsWorkerStore(SQLBaseStore):
                 "fetch_events", self.runWithConnection, self._do_fetch
             )
 
-        logger.debug("Loading %d events", len(events))
+        logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
-            rows = yield events_d
-        logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
-        if not allow_rejected:
-            rows[:] = [r for r in rows if r["rejected_reason"] is None]
-
-        res = yield make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self._get_event_from_row,
-                        row["internal_metadata"],
-                        row["json"],
-                        row["redactions"],
-                        rejected_reason=row["rejected_reason"],
-                        format_version=row["format_version"],
-                    )
-                    for row in rows
-                ],
-                consumeErrors=True,
-            )
-        )
+            row_map = yield events_d
+        logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
 
-        return {e.event.event_id: e for e in res if e}
+        return row_map
 
     def _fetch_event_rows(self, txn, event_ids):
         """Fetch event rows from the database
@@ -580,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_dict
 
-    @defer.inlineCallbacks
-    def _get_event_from_row(
-        self, internal_metadata, js, redactions, format_version, rejected_reason=None
-    ):
-        """Parse an event row which has been read from the database
-
-        Args:
-            internal_metadata (str): json-encoded internal_metadata column
-            js (str): json-encoded event body from event_json
-            redactions (list[str]): a list of the events which claim to have redacted
-                this event, from the redactions table
-            format_version: (str): the 'format_version' column
-            rejected_reason (str|None): the reason this event was rejected, if any
-
-        Returns:
-            _EventCacheEntry
-        """
-        with Measure(self._clock, "_get_event_from_row"):
-            d = json.loads(js)
-            internal_metadata = json.loads(internal_metadata)
-
-            if format_version is None:
-                # This means that we stored the event before we had the concept
-                # of a event format version, so it must be a V1 event.
-                format_version = EventFormatVersions.V1
-
-            original_ev = event_type_from_format_version(format_version)(
-                event_dict=d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev, redacted_event=redacted_event
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-        return cache_entry
-
-    @defer.inlineCallbacks
-    def _maybe_redact_event_row(self, original_ev, redactions):
+    def _maybe_redact_event_row(self, original_ev, redactions, event_map):
         """Given an event object and a list of possible redacting event ids,
         determine whether to honour any of those redactions and if so return a redacted
         event.
@@ -631,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore):
         Args:
              original_ev (EventBase):
              redactions (iterable[str]): list of event ids of potential redaction events
+             event_map (dict[str, EventBase]): other events which have been fetched, in
+                 which we can look up the redaaction events. Map from event id to event.
 
         Returns:
             Deferred[EventBase|None]: if the event should be redacted, a pruned
@@ -640,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore):
             # we choose to ignore redactions of m.room.create events.
             return None
 
-        if original_ev.type == "m.room.redaction":
-            # ... and redaction events
-            return None
-
-        redaction_map = yield self._get_events_from_cache_or_db(redactions)
-
         for redaction_id in redactions:
-            redaction_entry = redaction_map.get(redaction_id)
-            if not redaction_entry:
+            redaction_event = event_map.get(redaction_id)
+            if not redaction_event or redaction_event.rejected_reason:
                 # we don't have the redaction event, or the redaction event was not
                 # authorized.
                 logger.debug(
@@ -658,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore):
                 )
                 continue
 
-            redaction_event = redaction_entry.event
             if redaction_event.room_id != original_ev.room_id:
                 logger.debug(
                     "%s was redacted by %s but redaction was in a different room!",
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 8488b6edc8..d961b81d48 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -17,6 +17,8 @@
 
 from mock import Mock
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomID, UserID
@@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
             event.unsigned["redacted_because"],
         )
+
+    def test_circular_redaction(self):
+        redaction_event_id1 = "$redaction1_id:test"
+        redaction_event_id2 = "$redaction2_id:test"
+
+        class EventIdManglingBuilder:
+            def __init__(self, base_builder, event_id):
+                self._base_builder = base_builder
+                self._event_id = event_id
+
+            @defer.inlineCallbacks
+            def build(self, prev_event_ids):
+                built_event = yield self._base_builder.build(prev_event_ids)
+                built_event.event_id = self._event_id
+                built_event._event_dict["event_id"] = self._event_id
+                return built_event
+
+            @property
+            def room_id(self):
+                return self._base_builder.room_id
+
+        event_1, context_1 = self.get_success(
+            self.event_creation_handler.create_new_client_event(
+                EventIdManglingBuilder(
+                    self.event_builder_factory.for_room_version(
+                        RoomVersions.V1,
+                        {
+                            "type": EventTypes.Redaction,
+                            "sender": self.u_alice.to_string(),
+                            "room_id": self.room1.to_string(),
+                            "content": {"reason": "test"},
+                            "redacts": redaction_event_id2,
+                        },
+                    ),
+                    redaction_event_id1,
+                )
+            )
+        )
+
+        self.get_success(self.store.persist_event(event_1, context_1))
+
+        event_2, context_2 = self.get_success(
+            self.event_creation_handler.create_new_client_event(
+                EventIdManglingBuilder(
+                    self.event_builder_factory.for_room_version(
+                        RoomVersions.V1,
+                        {
+                            "type": EventTypes.Redaction,
+                            "sender": self.u_alice.to_string(),
+                            "room_id": self.room1.to_string(),
+                            "content": {"reason": "test"},
+                            "redacts": redaction_event_id1,
+                        },
+                    ),
+                    redaction_event_id2,
+                )
+            )
+        )
+        self.get_success(self.store.persist_event(event_2, context_2))
+
+        # fetch one of the redactions
+        fetched = self.get_success(self.store.get_event(redaction_event_id1))
+
+        # it should have been redacted
+        self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2)
+        self.assertEqual(
+            fetched.unsigned["redacted_because"].event_id, redaction_event_id2
+        )