summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/events_worker.py96
-rw-r--r--tests/storage/test_redaction.py70
2 files changed, 106 insertions, 60 deletions
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index e15e7d86fe..c6fa7f82fd 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -483,7 +483,8 @@ class EventsWorkerStore(SQLBaseStore):
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
 
-        result_map = {}
+        # build a map from event_id to EventBase
+        event_map = {}
         for event_id, row in fetched_events.items():
             if not row:
                 continue
@@ -494,14 +495,37 @@ class EventsWorkerStore(SQLBaseStore):
             if not allow_rejected and rejected_reason:
                 continue
 
-            cache_entry = yield self._get_event_from_row(
-                row["internal_metadata"],
-                row["json"],
-                row["redactions"],
-                rejected_reason=row["rejected_reason"],
-                format_version=row["format_version"],
+            d = json.loads(row["json"])
+            internal_metadata = json.loads(row["internal_metadata"])
+
+            format_version = row["format_version"]
+            if format_version is None:
+                # This means that we stored the event before we had the concept
+                # of a event format version, so it must be a V1 event.
+                format_version = EventFormatVersions.V1
+
+            original_ev = event_type_from_format_version(format_version)(
+                event_dict=d,
+                internal_metadata_dict=internal_metadata,
+                rejected_reason=rejected_reason,
             )
 
+            event_map[event_id] = original_ev
+
+        # finally, we can decide whether each one nededs redacting, and build
+        # the cache entries.
+        result_map = {}
+        for event_id, original_ev in event_map.items():
+            redactions = fetched_events[event_id]["redactions"]
+            redacted_event = self._maybe_redact_event_row(
+                original_ev, redactions, event_map
+            )
+
+            cache_entry = _EventCacheEntry(
+                event=original_ev, redacted_event=redacted_event
+            )
+
+            self._get_event_cache.prefill((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
         return result_map
@@ -615,50 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_dict
 
-    @defer.inlineCallbacks
-    def _get_event_from_row(
-        self, internal_metadata, js, redactions, format_version, rejected_reason=None
-    ):
-        """Parse an event row which has been read from the database
-
-        Args:
-            internal_metadata (str): json-encoded internal_metadata column
-            js (str): json-encoded event body from event_json
-            redactions (list[str]): a list of the events which claim to have redacted
-                this event, from the redactions table
-            format_version: (str): the 'format_version' column
-            rejected_reason (str|None): the reason this event was rejected, if any
-
-        Returns:
-            _EventCacheEntry
-        """
-        with Measure(self._clock, "_get_event_from_row"):
-            d = json.loads(js)
-            internal_metadata = json.loads(internal_metadata)
-
-            if format_version is None:
-                # This means that we stored the event before we had the concept
-                # of a event format version, so it must be a V1 event.
-                format_version = EventFormatVersions.V1
-
-            original_ev = event_type_from_format_version(format_version)(
-                event_dict=d,
-                internal_metadata_dict=internal_metadata,
-                rejected_reason=rejected_reason,
-            )
-
-            redacted_event = yield self._maybe_redact_event_row(original_ev, redactions)
-
-            cache_entry = _EventCacheEntry(
-                event=original_ev, redacted_event=redacted_event
-            )
-
-            self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
-        return cache_entry
-
-    @defer.inlineCallbacks
-    def _maybe_redact_event_row(self, original_ev, redactions):
+    def _maybe_redact_event_row(self, original_ev, redactions, event_map):
         """Given an event object and a list of possible redacting event ids,
         determine whether to honour any of those redactions and if so return a redacted
         event.
@@ -666,6 +647,8 @@ class EventsWorkerStore(SQLBaseStore):
         Args:
              original_ev (EventBase):
              redactions (iterable[str]): list of event ids of potential redaction events
+             event_map (dict[str, EventBase]): other events which have been fetched, in
+                 which we can look up the redaaction events. Map from event id to event.
 
         Returns:
             Deferred[EventBase|None]: if the event should be redacted, a pruned
@@ -675,15 +658,9 @@ class EventsWorkerStore(SQLBaseStore):
             # we choose to ignore redactions of m.room.create events.
             return None
 
-        if original_ev.type == "m.room.redaction":
-            # ... and redaction events
-            return None
-
-        redaction_map = yield self._get_events_from_cache_or_db(redactions)
-
         for redaction_id in redactions:
-            redaction_entry = redaction_map.get(redaction_id)
-            if not redaction_entry:
+            redaction_event = event_map.get(redaction_id)
+            if not redaction_event or redaction_event.rejected_reason:
                 # we don't have the redaction event, or the redaction event was not
                 # authorized.
                 logger.debug(
@@ -693,7 +670,6 @@ class EventsWorkerStore(SQLBaseStore):
                 )
                 continue
 
-            redaction_event = redaction_entry.event
             if redaction_event.room_id != original_ev.room_id:
                 logger.debug(
                     "%s was redacted by %s but redaction was in a different room!",
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 8488b6edc8..d961b81d48 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -17,6 +17,8 @@
 
 from mock import Mock
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.types import RoomID, UserID
@@ -216,3 +218,71 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
             event.unsigned["redacted_because"],
         )
+
+    def test_circular_redaction(self):
+        redaction_event_id1 = "$redaction1_id:test"
+        redaction_event_id2 = "$redaction2_id:test"
+
+        class EventIdManglingBuilder:
+            def __init__(self, base_builder, event_id):
+                self._base_builder = base_builder
+                self._event_id = event_id
+
+            @defer.inlineCallbacks
+            def build(self, prev_event_ids):
+                built_event = yield self._base_builder.build(prev_event_ids)
+                built_event.event_id = self._event_id
+                built_event._event_dict["event_id"] = self._event_id
+                return built_event
+
+            @property
+            def room_id(self):
+                return self._base_builder.room_id
+
+        event_1, context_1 = self.get_success(
+            self.event_creation_handler.create_new_client_event(
+                EventIdManglingBuilder(
+                    self.event_builder_factory.for_room_version(
+                        RoomVersions.V1,
+                        {
+                            "type": EventTypes.Redaction,
+                            "sender": self.u_alice.to_string(),
+                            "room_id": self.room1.to_string(),
+                            "content": {"reason": "test"},
+                            "redacts": redaction_event_id2,
+                        },
+                    ),
+                    redaction_event_id1,
+                )
+            )
+        )
+
+        self.get_success(self.store.persist_event(event_1, context_1))
+
+        event_2, context_2 = self.get_success(
+            self.event_creation_handler.create_new_client_event(
+                EventIdManglingBuilder(
+                    self.event_builder_factory.for_room_version(
+                        RoomVersions.V1,
+                        {
+                            "type": EventTypes.Redaction,
+                            "sender": self.u_alice.to_string(),
+                            "room_id": self.room1.to_string(),
+                            "content": {"reason": "test"},
+                            "redacts": redaction_event_id1,
+                        },
+                    ),
+                    redaction_event_id2,
+                )
+            )
+        )
+        self.get_success(self.store.persist_event(event_2, context_2))
+
+        # fetch one of the redactions
+        fetched = self.get_success(self.store.get_event(redaction_event_id1))
+
+        # it should have been redacted
+        self.assertEqual(fetched.unsigned["redacted_by"], redaction_event_id2)
+        self.assertEqual(
+            fetched.unsigned["redacted_because"].event_id, redaction_event_id2
+        )