summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5699.bugfix1
-rw-r--r--synapse/storage/events_worker.py122
2 files changed, 75 insertions, 48 deletions
diff --git a/changelog.d/5699.bugfix b/changelog.d/5699.bugfix
new file mode 100644
index 0000000000..30d5e67f67
--- /dev/null
+++ b/changelog.d/5699.bugfix
@@ -0,0 +1 @@
+Fix some problems with authenticating redactions in recent room versions.
\ No newline at end of file
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 874d0a56bc..6d6bb1d5d3 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -218,37 +218,23 @@ class EventsWorkerStore(SQLBaseStore):
         if not event_ids:
             defer.returnValue([])
 
-        event_id_list = event_ids
-        event_ids = set(event_ids)
-
-        event_entry_map = self._get_events_from_cache(
-            event_ids, allow_rejected=allow_rejected
+        # there may be duplicates so we cast the list to a set
+        event_entry_map = yield self._get_events_from_cache_or_db(
+            set(event_ids), allow_rejected=allow_rejected
         )
 
-        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
-        if missing_events_ids:
-            log_ctx = LoggingContext.current_context()
-            log_ctx.record_event_fetch(len(missing_events_ids))
-
-            # Note that _enqueue_events is also responsible for turning db rows
-            # into FrozenEvents (via _get_event_from_row), which involves seeing if
-            # the events have been redacted, and if so pulling the redaction event out
-            # of the database to check it.
-            #
-            # _enqueue_events is a bit of a rubbish name but naming is hard.
-            missing_events = yield self._enqueue_events(
-                missing_events_ids, allow_rejected=allow_rejected
-            )
-
-            event_entry_map.update(missing_events)
-
         events = []
-        for event_id in event_id_list:
+        for event_id in event_ids:
             entry = event_entry_map.get(event_id, None)
             if not entry:
                 continue
 
+            if not allow_rejected:
+                assert not entry.event.rejected_reason, (
+                    "rejected event returned from _get_events_from_cache_or_db despite "
+                    "allow_rejected=False"
+                )
+
             # Starting in room version v3, some redactions need to be rechecked if we
             # didn't have the redacted event at the time, so we recheck on read
             # instead.
@@ -291,34 +277,74 @@ class EventsWorkerStore(SQLBaseStore):
                         # recheck.
                         entry.event.internal_metadata.recheck_redaction = False
                     else:
-                        # We don't have the event that is being redacted, so we
-                        # assume that the event isn't authorized for now. (If we
-                        # later receive the event, then we will always redact
-                        # it anyway, since we have this redaction)
+                        # We either don't have the event that is being redacted (so we
+                        # assume that the event isn't authorised for now), or the
+                        # senders don't match (so it will never be authorised). Either
+                        # way, we shouldn't return it.
+                        #
+                        # (If we later receive the event, then we will redact it anyway,
+                        # since we have this redaction)
                         continue
 
-            if allow_rejected or not entry.event.rejected_reason:
-                if check_redacted and entry.redacted_event:
-                    event = entry.redacted_event
-                else:
-                    event = entry.event
-
-                events.append(event)
-
-                if get_prev_content:
-                    if "replaces_state" in event.unsigned:
-                        prev = yield self.get_event(
-                            event.unsigned["replaces_state"],
-                            get_prev_content=False,
-                            allow_none=True,
-                        )
-                        if prev:
-                            event.unsigned = dict(event.unsigned)
-                            event.unsigned["prev_content"] = prev.content
-                            event.unsigned["prev_sender"] = prev.sender
+            if check_redacted and entry.redacted_event:
+                event = entry.redacted_event
+            else:
+                event = entry.event
+
+            events.append(event)
+
+            if get_prev_content:
+                if "replaces_state" in event.unsigned:
+                    prev = yield self.get_event(
+                        event.unsigned["replaces_state"],
+                        get_prev_content=False,
+                        allow_none=True,
+                    )
+                    if prev:
+                        event.unsigned = dict(event.unsigned)
+                        event.unsigned["prev_content"] = prev.content
+                        event.unsigned["prev_sender"] = prev.sender
 
         defer.returnValue(events)
 
+    @defer.inlineCallbacks
+    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+        """Fetch a bunch of events from the cache or the database.
+
+        If events are pulled from the database, they will be cached for future lookups.
+
+        Args:
+            event_ids (Iterable[str]): The event_ids of the events to fetch
+            allow_rejected (bool): Whether to include rejected events
+
+        Returns:
+            Deferred[Dict[str, _EventCacheEntry]]:
+                map from event id to result
+        """
+        event_entry_map = self._get_events_from_cache(
+            event_ids, allow_rejected=allow_rejected
+        )
+
+        missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+        if missing_events_ids:
+            log_ctx = LoggingContext.current_context()
+            log_ctx.record_event_fetch(len(missing_events_ids))
+
+            # Note that _enqueue_events is also responsible for turning db rows
+            # into FrozenEvents (via _get_event_from_row), which involves seeing if
+            # the events have been redacted, and if so pulling the redaction event out
+            # of the database to check it.
+            #
+            # _enqueue_events is a bit of a rubbish name but naming is hard.
+            missing_events = yield self._enqueue_events(
+                missing_events_ids, allow_rejected=allow_rejected
+            )
+
+            event_entry_map.update(missing_events)
+
+        return event_entry_map
+
     def _invalidate_get_event_cache(self, event_id):
         self._get_event_cache.invalidate((event_id,))
 
@@ -326,7 +352,7 @@ class EventsWorkerStore(SQLBaseStore):
         """Fetch events from the caches
 
         Args:
-            events (list(str)): list of event_ids to fetch
+            events (Iterable[str]): list of event_ids to fetch
             allow_rejected (bool): Whether to return events that were rejected
             update_metrics (bool): Whether to update the cache hit ratio metrics