summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13267.misc1
-rw-r--r--synapse/handlers/message.py7
-rw-r--r--synapse/state/__init__.py117
-rw-r--r--synapse/storage/controllers/__init__.py4
-rw-r--r--synapse/storage/controllers/persist_events.py12
-rw-r--r--synapse/storage/databases/main/roommember.py35
-rw-r--r--tests/storage/test_roommember.py55
7 files changed, 95 insertions, 136 deletions
diff --git a/changelog.d/13267.misc b/changelog.d/13267.misc
new file mode 100644
index 0000000000..a334414320
--- /dev/null
+++ b/changelog.d/13267.misc
@@ -0,0 +1 @@
+Don't pull out state in `compute_event_context` for unconflicted state.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1980e37dae..b5fede9496 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1444,7 +1444,12 @@ class EventCreationHandler:
             if state_entry.state_group in self._external_cache_joined_hosts_updates:
                 return
 
-            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+            state = await state_entry.get_state(
+                self._storage_controllers.state, StateFilter.all()
+            )
+            joined_hosts = await self.store.get_joined_hosts(
+                event.room_id, state, state_entry
+            )
 
             # Note that the expiry times must be larger than the expiry time in
             # _external_cache_joined_hosts_updates.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 781d9f06da..9f0a36652c 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -31,7 +31,6 @@ from typing import (
     Sequence,
     Set,
     Tuple,
-    Union,
 )
 
 import attr
@@ -47,6 +46,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -54,6 +54,7 @@ from synapse.util.metrics import Measure, measure_func
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.controllers import StateStorageController
     from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
@@ -83,17 +84,20 @@ def _gen_state_id() -> str:
 
 
 class _StateCacheEntry:
-    __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+    __slots__ = ["state", "state_group", "prev_group", "delta_ids"]
 
     def __init__(
         self,
-        state: StateMap[str],
+        state: Optional[StateMap[str]],
         state_group: Optional[int],
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ):
+        if state is None and state_group is None:
+            raise Exception("Either state or state group must be not None")
+
         # A map from (type, state_key) to event_id.
-        self.state = frozendict(state)
+        self.state = frozendict(state) if state is not None else None
 
         # the ID of a state group if one and only one is involved.
         # otherwise, None otherwise?
@@ -102,20 +106,30 @@ class _StateCacheEntry:
         self.prev_group = prev_group
         self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
 
-        # The `state_id` is a unique ID we generate that can be used as ID for
-        # this collection of state. Usually this would be the same as the
-        # state group, but on worker instances we can't generate a new state
-        # group each time we resolve state, so we generate a separate one that
-        # isn't persisted and is used solely for caches.
-        # `state_id` is either a state_group (and so an int) or a string. This
-        # ensures we don't accidentally persist a state_id as a stateg_group
-        if state_group:
-            self.state_id: Union[str, int] = state_group
-        else:
-            self.state_id = _gen_state_id()
+    async def get_state(
+        self,
+        state_storage: "StateStorageController",
+        state_filter: Optional["StateFilter"] = None,
+    ) -> StateMap[str]:
+        """Get the state map for this entry, either from the in-memory state or
+        looking up the state group in the DB.
+        """
+
+        if self.state is not None:
+            return self.state
+
+        assert self.state_group is not None
+
+        return await state_storage.get_state_ids_for_group(
+            self.state_group, state_filter
+        )
 
     def __len__(self) -> int:
-        return len(self.state)
+        # The len should is used to estimate how large this cache entry is, for
+        # cache eviction purposes. This is why if `self.state` is None it's fine
+        # to return 1.
+
+        return len(self.state) if self.state else 1
 
 
 class StateHandler:
@@ -153,7 +167,7 @@ class StateHandler:
         """
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return ret.state
+        return await ret.get_state(self._state_storage_controller, StateFilter.all())
 
     async def get_current_users_in_room(
         self, room_id: str, latest_event_ids: List[str]
@@ -177,7 +191,8 @@ class StateHandler:
 
         logger.debug("calling resolve_state_groups from get_current_users_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        return await self.store.get_joined_users_from_state(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_users_from_state(room_id, state, entry)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
@@ -192,7 +207,8 @@ class StateHandler:
             The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        return await self.store.get_joined_hosts(room_id, entry)
+        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
+        return await self.store.get_joined_hosts(room_id, state, entry)
 
     async def compute_event_context(
         self,
@@ -227,10 +243,19 @@ class StateHandler:
         #
         if state_ids_before_event:
             # if we're given the state before the event, then we use that
-            state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
-            entry = None
+
+            # .. though we need to get a state group for it.
+            state_group_before_event = (
+                await self._state_storage_controller.store_state_group(
+                    event.event_id,
+                    event.room_id,
+                    prev_group=None,
+                    delta_ids=None,
+                    current_state_ids=state_ids_before_event,
+                )
+            )
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -264,36 +289,27 @@ class StateHandler:
                 await_full_state=False,
             )
 
-            state_ids_before_event = entry.state
-            state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
 
-        #
-        # make sure that we have a state group at that point. If it's not a state event,
-        # that will be the state group for the new event. If it *is* a state event,
-        # it might get rejected (in which case we'll need to persist it with the
-        # previous state group)
-        #
-
-        if not state_group_before_event:
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=state_group_before_event_prev_group,
-                    delta_ids=deltas_to_state_group_before_event,
-                    current_state_ids=state_ids_before_event,
+            # We make sure that we have a state group assigned to the state.
+            if entry.state_group is None:
+                state_ids_before_event = await entry.get_state(
+                    self._state_storage_controller, StateFilter.all()
+                )
+                state_group_before_event = (
+                    await self._state_storage_controller.store_state_group(
+                        event.event_id,
+                        event.room_id,
+                        prev_group=state_group_before_event_prev_group,
+                        delta_ids=deltas_to_state_group_before_event,
+                        current_state_ids=state_ids_before_event,
+                    )
                 )
-            )
-
-            # Assign the new state group to the cached state entry.
-            #
-            # Note that this can race in that we could generate multiple state
-            # groups for the same state entry, but that is just inefficient
-            # rather than dangerous.
-            if entry and entry.state_group is None:
                 entry.state_group = state_group_before_event
+            else:
+                state_group_before_event = entry.state_group
+                state_ids_before_event = None
 
         #
         # now if it's not a state event, we're done
@@ -313,6 +329,10 @@ class StateHandler:
         #
         # otherwise, we'll need to create a new state group for after the event
         #
+        if state_ids_before_event is None:
+            state_ids_before_event = await entry.get_state(
+                self._state_storage_controller, StateFilter.all()
+            )
 
         key = (event.type, event.state_key)
         if key in state_ids_before_event:
@@ -372,9 +392,6 @@ class StateHandler:
         state_group_ids_set = set(state_group_ids)
         if len(state_group_ids_set) == 1:
             (state_group_id,) = state_group_ids_set
-            state = await self._state_storage_controller.get_state_for_groups(
-                state_group_ids_set
-            )
             (
                 prev_group,
                 delta_ids,
@@ -382,7 +399,7 @@ class StateHandler:
                 state_group_id
             )
             return _StateCacheEntry(
-                state=state[state_group_id],
+                state=None,
                 state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py
index 55649719f6..45101cda7a 100644
--- a/synapse/storage/controllers/__init__.py
+++ b/synapse/storage/controllers/__init__.py
@@ -43,4 +43,6 @@ class StorageControllers:
 
         self.persistence = None
         if stores.persist_events:
-            self.persistence = EventsPersistenceStorageController(hs, stores)
+            self.persistence = EventsPersistenceStorageController(
+                hs, stores, self.state
+            )
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index ea499ce0f8..af65e5913b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
 from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
     current state and forward extremity changes.
     """
 
-    def __init__(self, hs: "HomeServer", stores: Databases):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        stores: Databases,
+        state_controller: StateStorageController,
+    ):
         # We ultimately want to split out the state store from the main store,
         # so we use separate variables here even though they point to the same
         # store for now.
@@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
             self._process_event_persist_queue_task
         )
         self._state_resolution_handler = hs.get_state_resolution_handler()
+        self._state_controller = state_controller
 
     async def _process_event_persist_queue_task(
         self,
@@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
             state_res_store=StateResolutionStore(self.main_store),
         )
 
-        return res.state
+        return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
         self, _room_id: str, task: _PersistEventsTask
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 0b5e4e4254..71a65d565a 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -31,7 +31,6 @@ import attr
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
@@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_context(
-        self, event: EventBase, context: EventContext
-    ) -> Dict[str, ProfileInfo]:
-        state_group: Union[object, int] = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = await context.get_current_state_ids()
-        assert current_state_ids is not None
-        assert state_group is not None
-        return await self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
-
     async def get_joined_users_from_state(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
-                room_id, state_group, state_entry.state, context=state_entry
+                room_id, state_group, state, context=state_entry
             )
 
-    @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+    @cached(num_args=2, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
-        cache_context: _CacheContext,
         event: Optional[EventBase] = None,
-        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
+        context: Optional["_StateCacheEntry"] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -1017,7 +997,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     async def get_joined_hosts(
-        self, room_id: str, state_entry: "_StateCacheEntry"
+        self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
@@ -1030,7 +1010,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
-                room_id, state_group, state_entry=state_entry
+                room_id, state_group, state, state_entry=state_entry
             )
 
     @cached(num_args=2, max_entries=10000, iterable=True)
@@ -1038,6 +1018,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self,
         room_id: str,
         state_group: Union[object, int],
+        state: StateMap[str],
         state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
@@ -1093,7 +1074,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
                 joined_users = await self.get_joined_users_from_state(
-                    room_id, state_entry
+                    room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1218786d79..240b02cb9f 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -23,7 +23,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import TestHomeServer
-from tests.test_utils import event_injection
 
 
 class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@@ -110,60 +109,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # It now knows about Charlie's server.
         self.assertEqual(self.store._known_servers_count, 2)
 
-    def test_get_joined_users_from_context(self) -> None:
-        room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
-        bob_event = self.get_success(
-            event_injection.inject_member_event(
-                self.hs, room, self.u_bob, Membership.JOIN
-            )
-        )
-
-        # first, create a regular event
-        event, context = self.get_success(
-            event_injection.create_event(
-                self.hs,
-                room_id=room,
-                sender=self.u_alice,
-                prev_event_ids=[bob_event.event_id],
-                type="m.test.1",
-                content={},
-            )
-        )
-
-        users = self.get_success(
-            self.store.get_joined_users_from_context(event, context)
-        )
-        self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
-        # Regression test for #7376: create a state event whose key matches bob's
-        # user_id, but which is *not* a membership event, and persist that; then check
-        # that `get_joined_users_from_context` returns the correct users for the next event.
-        non_member_event = self.get_success(
-            event_injection.inject_event(
-                self.hs,
-                room_id=room,
-                sender=self.u_bob,
-                prev_event_ids=[bob_event.event_id],
-                type="m.test.2",
-                state_key=self.u_bob,
-                content={},
-            )
-        )
-        event, context = self.get_success(
-            event_injection.create_event(
-                self.hs,
-                room_id=room,
-                sender=self.u_alice,
-                prev_event_ids=[non_member_event.event_id],
-                type="m.test.3",
-                content={},
-            )
-        )
-        users = self.get_success(
-            self.store.get_joined_users_from_context(event, context)
-        )
-        self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
-
     def test__null_byte_in_display_name_properly_handled(self) -> None:
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)