summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/events_worker.py33
1 files changed, 19 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 01e935edef..318fd7dc71 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging
 import threading
 import weakref
 from enum import Enum, auto
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Any,
     Collection,
-    Container,
     Dict,
     Iterable,
     List,
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_stripped_room_state_from_event_context(
         self,
         context: EventContext,
-        state_types_to_include: Container[str],
+        state_keys_to_include: StateFilter,
         membership_user_id: Optional[str] = None,
     ) -> List[JsonDict]:
         """
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
             context: The event context to retrieve state of the room from.
-            state_types_to_include: The type of state events to include.
+            state_keys_to_include: The state events to include, for each event type.
             membership_user_id: An optional user ID to include the stripped membership state
                 events of. This is useful when generating the stripped state of a room for
                 invites. We want to send membership events of the inviter, so that the
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             A list of dictionaries, each representing a stripped state event from the room.
         """
-        current_state_ids = await context.get_current_state_ids()
+        if membership_user_id:
+            types = chain(
+                state_keys_to_include.to_types(),
+                [(EventTypes.Member, membership_user_id)],
+            )
+            filter = StateFilter.from_types(types)
+        else:
+            filter = state_keys_to_include
+        selected_state_ids = await context.get_current_state_ids(filter)
 
         # We know this event is not an outlier, so this must be
         # non-None.
-        assert current_state_ids is not None
-
-        # The state to include
-        state_to_include_ids = [
-            e_id
-            for k, e_id in current_state_ids.items()
-            if k[0] in state_types_to_include
-            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
-        ]
+        assert selected_state_ids is not None
+
+        # Confusingly, get_current_state_events may return events that are discarded by
+        # the filter, if they're in context._state_delta_due_to_event. Strip these away.
+        selected_state_ids = filter.filter_state(selected_state_ids)
 
-        state_to_include = await self.get_events(state_to_include_ids)
+        state_to_include = await self.get_events(selected_state_ids.values())
 
         return [
             {