summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/builder.py4
-rw-r--r--synapse/handlers/sync.py8
-rw-r--r--synapse/state/__init__.py35
-rw-r--r--synapse/storage/state.py49
-rw-r--r--synapse/visibility.py1
5 files changed, 78 insertions, 19 deletions
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 98c203ada0..68ab2113d4 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -120,8 +120,10 @@ class EventBuilder:
             The signed and hashed event.
         """
         if auth_event_ids is None:
+            # we pick the auth events based on our best knowledge of the current state
+            # of the room, so we don't need to await full state.
             state_ids = await self._state.get_current_state_ids(
-                self.room_id, prev_event_ids
+                self.room_id, prev_event_ids, await_full_state=False
             )
             auth_event_ids = self._event_auth_handler.compute_auth_events(
                 self, state_ids
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 59b5d497be..ed0e8a9fe6 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -902,11 +902,15 @@ class SyncHandler:
             if full_state:
                 if batch:
                     current_state_ids = await self.state_store.get_state_ids_for_event(
-                        batch.events[-1].event_id, state_filter=state_filter
+                        batch.events[-1].event_id,
+                        state_filter=state_filter,
+                        await_full_state=not lazy_load_members,  # TODO
                     )
 
                     state_ids = await self.state_store.get_state_ids_for_event(
-                        batch.events[0].event_id, state_filter=state_filter
+                        batch.events[0].event_id,
+                        state_filter=state_filter,
+                        await_full_state=not lazy_load_members,  # TODO
                     )
 
                 else:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4b4ed42cff..098b5f32ff 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -48,6 +48,7 @@ from synapse.logging.context import ContextResourceUsage
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -177,7 +178,16 @@ class StateHandler:
         assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+
+        filter = StateFilter.all()
+        if event_type:
+            filter = StateFilter.from_types(((event_type, state_key),))
+
+        ret = await self.resolve_state_groups_for_events(
+            room_id,
+            latest_event_ids,
+            await_full_state=filter.must_await_full_state(self.hs.is_mine_id),
+        )
         state = ret.state
 
         if event_type:
@@ -195,7 +205,10 @@ class StateHandler:
         }
 
     async def get_current_state_ids(
-        self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
+        self,
+        room_id: str,
+        latest_event_ids: Optional[Collection[str]] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Get the current state, or the state at a set of events, for a room
 
@@ -203,6 +216,8 @@ class StateHandler:
             room_id:
             latest_event_ids: if given, the forward extremities to resolve. If
                 None, we look them up from the database (via a cache).
+            await_full_state: if true, will block if we do not yet have complete
+               state at the latest events.
 
         Returns:
             the state dict, mapping from (event_type, state_key) -> event_id
@@ -212,7 +227,9 @@ class StateHandler:
         assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        ret = await self.resolve_state_groups_for_events(
+            room_id, latest_event_ids, await_full_state=await_full_state
+        )
         return ret.state
 
     async def get_current_users_in_room(
@@ -323,7 +340,9 @@ class StateHandler:
 
             logger.debug("calling resolve_state_groups from compute_event_context")
             entry = await self.resolve_state_groups_for_events(
-                event.room_id, event.prev_event_ids()
+                event.room_id,
+                event.prev_event_ids(),
+                await_full_state=False,
             )
 
             state_ids_before_event = entry.state
@@ -404,7 +423,7 @@ class StateHandler:
 
     @measure_func()
     async def resolve_state_groups_for_events(
-        self, room_id: str, event_ids: Collection[str]
+        self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
     ) -> _StateCacheEntry:
         """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -412,13 +431,17 @@ class StateHandler:
         Args:
             room_id
             event_ids
+            await_full_state: if true, will block if we do not yet have complete
+               state at these events.
 
         Returns:
             The resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        state_groups = await self.state_store.get_state_group_for_events(event_ids)
+        state_groups = await self.state_store.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         state_group_ids = state_groups.values()
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e58301a8f0..a7e721aef2 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -609,13 +609,18 @@ class StateGroupStorage:
         return state_group_delta.prev_group, state_group_delta.delta_ids
 
     async def get_state_groups_ids(
-        self, _room_id: str, event_ids: Collection[str]
+        self,
+        _room_id: str,
+        event_ids: Collection[str],
+        await_full_state: bool = True,
     ) -> Dict[int, MutableStateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
             _room_id: id of the room for these events
             event_ids: ids of the events
+            await_full_state: if true, will block if we do not yet have complete
+               state at these events.
 
         Returns:
             dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -627,7 +632,9 @@ class StateGroupStorage:
         if not event_ids:
             return {}
 
-        event_to_groups = await self.get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -700,7 +707,10 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
+        self,
+        event_ids: Collection[str],
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -708,6 +718,8 @@ class StateGroupStorage:
         Args:
             event_ids: The events to fetch the state of.
             state_filter: The state filter used to fetch state.
+            await_full_state: if true, will block if the state_filter includes state
+               which is not yet complete.
 
         Returns:
             A dict of (event_id) -> (type, state_key) -> [state_events]
@@ -716,8 +728,11 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                (ie they are outliers or unknown)
         """
-        await_full_state = True
-        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+        if (
+            await_full_state
+            and state_filter
+            and not state_filter.must_await_full_state(self._is_mine_id)
+        ):
             await_full_state = False
 
         event_to_groups = await self.get_state_group_for_events(
@@ -749,6 +764,7 @@ class StateGroupStorage:
         self,
         event_ids: Collection[str],
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -757,6 +773,8 @@ class StateGroupStorage:
         Args:
             event_ids: events whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if true, will block if the state_filter includes state
+               which is not yet complete.
 
         Returns:
             A dict from event_id -> (type, state_key) -> event_id
@@ -765,8 +783,12 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        await_full_state = True
-        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+
+        if (
+            await_full_state
+            and state_filter
+            and not state_filter.must_await_full_state(self._is_mine_id)
+        ):
             await_full_state = False
 
         event_to_groups = await self.get_state_group_for_events(
@@ -808,7 +830,10 @@ class StateGroupStorage:
         return state_map[event_id]
 
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -816,6 +841,8 @@ class StateGroupStorage:
         Args:
             event_id: event whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if true, will block if the state_filter includes state
+               which is not yet complete.
 
         Returns:
             A dict from (type, state_key) -> state_event_id
@@ -825,7 +852,9 @@ class StateGroupStorage:
                 outlier or is unknown)
         """
         state_map = await self.get_state_ids_for_events(
-            [event_id], state_filter or StateFilter.all()
+            [event_id],
+            state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
         return state_map[event_id]
 
@@ -857,7 +886,7 @@ class StateGroupStorage:
         Args:
             event_ids: events to get state groups for
             await_full_state: if true, will block if we do not yet have complete
-               state at these events.
+               state at these event.
         """
         if await_full_state:
             await self._partial_state_events_tracker.await_full_state(event_ids)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index de6d2ffc52..b851e660fd 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -85,6 +85,7 @@ async def filter_events_for_client(
     event_id_to_state = await storage.state.get_state_for_events(
         frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
         state_filter=StateFilter.from_types(types),
+        await_full_state=False,
     )
 
     # Get the users who are ignored by the requesting user.