From 10280fc9437038f7ef715873e491d54b0a6d2208 Mon Sep 17 00:00:00 2001 From: David Teller Date: Fri, 20 May 2022 14:53:25 +0200 Subject: Uniformize spam-checker API, part 1: the `Code` enum. (#12703) --- changelog.d/12703.misc | 1 + synapse/api/errors.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12703.misc diff --git a/changelog.d/12703.misc b/changelog.d/12703.misc new file mode 100644 index 0000000000..9aaa1bbaa3 --- /dev/null +++ b/changelog.d/12703.misc @@ -0,0 +1 @@ +Convert namespace class `Codes` into a string enum. \ No newline at end of file diff --git a/synapse/api/errors.py b/synapse/api/errors.py index cb3b7323d5..9614be6b4e 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -17,6 +17,7 @@ import logging import typing +from enum import Enum from http import HTTPStatus from typing import Any, Dict, List, Optional, Union @@ -30,7 +31,11 @@ if typing.TYPE_CHECKING: logger = logging.getLogger(__name__) -class Codes: +class Codes(str, Enum): + """ + All known error codes, as an enum of strings. + """ + UNRECOGNIZED = "M_UNRECOGNIZED" UNAUTHORIZED = "M_UNAUTHORIZED" FORBIDDEN = "M_FORBIDDEN" @@ -265,7 +270,9 @@ class UnrecognizedRequestError(SynapseError): """An error indicating we don't understand the request you're trying to make""" def __init__( - self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED + self, + msg: str = "Unrecognized request", + errcode: str = Codes.UNRECOGNIZED, ): super().__init__(400, msg, errcode) -- cgit 1.4.1 From 39dee30f0120290d6ef3504815655df1a6cf47a5 Mon Sep 17 00:00:00 2001 From: reivilibre Date: Fri, 20 May 2022 15:28:23 +0100 Subject: Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands. (#12809) --- changelog.d/12672.feature | 1 + changelog.d/12672.misc | 1 - changelog.d/12809.feature | 1 + synapse/replication/tcp/commands.py | 12 ++++++++++++ synapse/replication/tcp/redis.py | 6 +++--- 5 files changed, 17 insertions(+), 4 deletions(-) create mode 100644 changelog.d/12672.feature delete mode 100644 changelog.d/12672.misc create mode 100644 changelog.d/12809.feature diff --git a/changelog.d/12672.feature b/changelog.d/12672.feature new file mode 100644 index 0000000000..b989e0d208 --- /dev/null +++ b/changelog.d/12672.feature @@ -0,0 +1 @@ +Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands. \ No newline at end of file diff --git a/changelog.d/12672.misc b/changelog.d/12672.misc deleted file mode 100644 index 265e0a801f..0000000000 --- a/changelog.d/12672.misc +++ /dev/null @@ -1 +0,0 @@ -Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. \ No newline at end of file diff --git a/changelog.d/12809.feature b/changelog.d/12809.feature new file mode 100644 index 0000000000..b989e0d208 --- /dev/null +++ b/changelog.d/12809.feature @@ -0,0 +1 @@ +Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands. \ No newline at end of file diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index fe34948168..32f52e54d8 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -58,6 +58,15 @@ class Command(metaclass=abc.ABCMeta): # by default, we just use the command name. return self.NAME + def redis_channel_name(self, prefix: str) -> str: + """ + Returns the Redis channel name upon which to publish this command. + + Args: + prefix: The prefix for the channel. + """ + return prefix + SC = TypeVar("SC", bound="_SimpleCommand") @@ -395,6 +404,9 @@ class UserIpCommand(Command): f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})" ) + def redis_channel_name(self, prefix: str) -> str: + return f"{prefix}/USER_IP" + class RemoteServerUpCommand(_SimpleCommand): """Sent when a worker has detected that a remote server is no longer diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 73294654ef..fd1c0ec6af 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -221,10 +221,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol): # remote instances. tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() + channel_name = cmd.redis_channel_name(self.synapse_stream_prefix) + await make_deferred_yieldable( - self.synapse_outbound_redis_connection.publish( - self.synapse_stream_prefix, encoded_string - ) + self.synapse_outbound_redis_connection.publish(channel_name, encoded_string) ) -- cgit 1.4.1 From 2ebb0c6f994732928bc6040976880d74477e5e79 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 21 May 2022 13:58:52 +0100 Subject: Pull out less state when handling gaps --- synapse/handlers/federation_event.py | 140 +++++++++++++++----------------- synapse/handlers/message.py | 17 +++- synapse/state/__init__.py | 14 ++-- synapse/storage/databases/main/state.py | 42 ++++++++++ 4 files changed, 127 insertions(+), 86 deletions(-) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 05c122f224..8521a230cb 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -463,7 +463,9 @@ class FederationEventHandler: with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in state + }, partial_state=partial_state, ) @@ -501,7 +503,7 @@ class FederationEventHandler: # build a new state group for it if need be context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event=state, ) if context.partial_state: # this can happen if some or all of the event's prev_events still have @@ -765,7 +767,7 @@ class FederationEventHandler: async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: + ) -> Optional[StateMap[str]]: """Calculate the state at an event with missing prev_events. This is used when we have pulled a batch of events from a remote server, and @@ -792,8 +794,8 @@ class FederationEventHandler: event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. + if we already had all the prev events, `None`. Otherwise, returns + the state at `event`. """ room_id = event.room_id event_id = event.event_id @@ -837,13 +839,7 @@ class FederationEventHandler: dest, room_id, p ) - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } - state_maps.append(remote_state_map) - - for x in remote_state: - event_map[x.event_id] = x + state_maps.append(remote_state) room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( @@ -854,19 +850,6 @@ class FederationEventHandler: state_res_store=StateResolutionStore(self._store), ) - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self._store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.as_is, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "Error attempting to resolve state at missing prev_events", @@ -878,14 +861,14 @@ class FederationEventHandler: "We can't get valid state history.", affected=event_id, ) - return state + return state_map async def _get_state_after_missing_prev_event( self, destination: str, room_id: str, event_id: str, - ) -> List[EventBase]: + ) -> StateMap[str]: """Requests all of the room state at a given event from a remote homeserver. Args: @@ -894,7 +877,7 @@ class FederationEventHandler: event_id: The id of the event we want the state at. Returns: - A list of events in the state, including the event itself + The state *after* the given event. """ ( state_event_ids, @@ -913,15 +896,13 @@ class FederationEventHandler: desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self._store.get_events( - desired_events, allow_rejected=True - ) + have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - fetched_events.keys() + missing_desired_events = desired_events - have_events logger.debug( "We are missing %i events (got %i)", len(missing_desired_events), - len(fetched_events), + len(have_events), ) # We probably won't need most of the auth events, so let's just check which @@ -932,7 +913,7 @@ class FederationEventHandler: # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. - missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events = set(auth_event_ids) - have_events missing_auth_events.difference_update( await self._store.have_seen_events(room_id, missing_auth_events) ) @@ -958,47 +939,54 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=missing_events ) - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self._store.get_events(missing_desired_events, allow_rejected=True) - ) + event_metadata = await self._store.get_metadata_for_events(state_event_ids) # check for events which were in the wrong room. # # this can happen if a remote server claims that the state or # auth_events at an event in room A are actually events in room B - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] + event_metadata = await self._store.get_metadata_for_events(state_event_ids) - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) + state_map = {} + + for state_event_id, metadata in event_metadata.items(): + if metadata.room_id != room_id: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + state_event_id, + metadata.room_id, + room_id, + ) + continue + + if metadata.state_key is None: + logger.warning( + "Remote server gave us non-state event in state: %s", state_event_id + ) + continue - del fetched_events[bad_event_id] + state_map[(metadata.event_type, metadata.state_key)] = state_event_id # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) + remote_event = await self._store.get_event( + event_id, + allow_none=True, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) if not remote_event: raise Exception("Unable to get missing prev_event %s" % (event_id,)) # missing state at that event is a warning, not a blocker # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. - failed_to_fetch = desired_events - fetched_events.keys() + failed_to_fetch = desired_events - event_metadata.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state events for %s %s", @@ -1006,14 +994,12 @@ class FederationEventHandler: failed_to_fetch, ) - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) + state_map[ + (remote_event.type, remote_event.state_key) + ] = remote_event.event_id - return remote_state + return state_map async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str @@ -1040,7 +1026,7 @@ class FederationEventHandler: self, origin: str, event: EventBase, - state: Optional[Iterable[EventBase]], + state: Optional[StateMap[str]], backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. @@ -1074,7 +1060,7 @@ class FederationEventHandler: try: context = await self._state_handler.compute_event_context( - event, old_state=state + event, state_ids_before_event=state ) context = await self._check_event_auth( origin, @@ -1565,7 +1551,7 @@ class FederationEventHandler: async def _check_for_soft_fail( self, event: EventBase, - state: Optional[Iterable[EventBase]], + state: Optional[StateMap[str]], origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1602,17 +1588,21 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_store.get_state_groups( + state_sets_d = await self._state_store.get_state_groups_ids( event.room_id, extrem_ids ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) + state_sets: List[StateMap[str]] = list(state_sets_d.values()) state_sets.append(state) - current_states = await self._state_handler.resolve_events( - room_version, state_sets, event + + current_state_ids = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_sets, + event_map={}, + state_res_store=StateResolutionStore(self._store), + ) ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } else: current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e566ff1f8e..79e4b1cce3 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1021,8 +1021,21 @@ class EventCreationHandler: # # TODO(faster_joins): figure out how this works, and make sure that the # old state is complete. - old_state = await self.store.get_events_as_list(state_event_ids) - context = await self.state.compute_event_context(event, old_state=old_state) + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map = {} + for event_id, data in metadata.items(): + if data.state_key is None: + raise Exception( + "Trying to set non-state event as state: %s", event_id + ) + + state_map[(data.event_type, data.state_key)] = event_id + + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map, + ) else: context = await self.state.compute_event_context(event) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4b4ed42cff..712facd3a8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -261,7 +261,7 @@ class StateHandler: async def compute_event_context( self, event: EventBase, - old_state: Optional[Iterable[EventBase]] = None, + state_ids_before_event: Optional[StateMap[str]] = None, partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,12 +273,12 @@ class StateHandler: Args: event: - old_state: The state at the event if it can't be + state_ids_before_event: The state at the event if it can't be calculated from existing events. This is normally only specified when receiving an event from federation where we don't have the prev events for, e.g. when backfilling. - partial_state: True if `old_state` is partial and omits non-critical - membership events + partial_state: True if `state_ids_before_event` is partial and omits + non-critical membership events Returns: The event context. """ @@ -288,11 +288,7 @@ class StateHandler: # # first of all, figure out the state before the event # - if old_state: - # if we're given the state before the event, then we use that - state_ids_before_event: StateMap[str] = { - (s.type, s.state_key): s.event_id for s in old_state - } + if state_ids_before_event: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 18ae8aee29..d939a0e0b6 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -16,6 +16,8 @@ import collections.abc import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple +import attr + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -26,6 +28,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore @@ -43,6 +46,15 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: Optional[str] + + def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: @@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return room_version + async def get_metadata_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, EventMetadata]: + """Get some metadata (room_id, type, state_key) for the given events.""" + + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", event_ids + ) + + sql = f""" + SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e + LEFT JOIN state_events USING (event_id) + WHERE {clause} + """ + + def get_metadata_for_events_txn( + txn: LoggingTransaction, + ) -> Dict[str, EventMetadata]: + txn.execute(sql, args) + return { + event_id: EventMetadata( + room_id=room_id, event_type=event_type, state_key=state_key + ) + for event_id, room_id, event_type, state_key in txn + } + + return await self.db_pool.runInteraction( + "get_metadata_for_events", get_metadata_for_events_txn + ) + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. -- cgit 1.4.1 From 8b33331cb57e199fc0907accfcabd18710f72079 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 21 May 2022 14:01:21 +0100 Subject: Newsfile --- changelog.d/12828.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/12828.misc diff --git a/changelog.d/12828.misc b/changelog.d/12828.misc new file mode 100644 index 0000000000..afca32471f --- /dev/null +++ b/changelog.d/12828.misc @@ -0,0 +1 @@ +Pull out less state when handling gaps in room DAG. -- cgit 1.4.1 From f9d470b2da67589a32970aee6abc0a02d9b45b87 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 21 May 2022 14:11:52 +0100 Subject: Fix tests --- tests/handlers/test_federation.py | 4 +++- tests/storage/test_events.py | 43 +++++++++++++++++++++++++-------------- tests/test_state.py | 14 +++++++++++-- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index e95dfdce20..183a6635fa 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -276,7 +276,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # federation handler wanting to backfill the fake event. self.get_success( federation_event_handler._process_received_pdu( - self.OTHER_SERVER_NAME, event, state=current_state + self.OTHER_SERVER_NAME, + event, + state={(e.type, e.state_key): e.event_id for e in current_state}, ) ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index ef5e25873c..aaa3189b16 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def persist_event(self, event, state=None): """Persist the event, with optional state""" context = self.get_success( - self.state.compute_event_context(event, old_state=state) + self.state.compute_event_context(event, state_ids_before_event=state) ) self.get_success(self.persistence.persist_event(event, context)) @@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase): # setting. The state resolution across the old and new event will then # include it, and so the resolved state won't match the new state. state_before_gap = dict( - self.get_success(self.state.get_current_state(self.room_id)) + self.get_success(self.state.get_current_state_ids(self.room_id)) ) state_before_gap.pop(("m.room.history_visibility", "")) context = self.get_success( self.state.compute_event_context( - remote_event_2, old_state=state_before_gap.values() + remote_event_2, + state_ids_before_event=state_before_gap, ) ) @@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id]) @@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id]) @@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([remote_event_2.event_id, local_message_event_id]) @@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase): RoomVersions.V6, ) - state_before_gap = self.get_success(self.state.get_current_state(self.room_id)) + state_before_gap = self.get_success( + self.state.get_current_state_ids(self.room_id) + ) - self.persist_event(remote_event_2, state=state_before_gap.values()) + self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. self.assert_extremities([local_message_event_id, remote_event_2.event_id]) diff --git a/tests/test_state.py b/tests/test_state.py index c6baea3d76..84694d368d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) @@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase): ] context = yield defer.ensureDeferred( - self.state.compute_event_context(event, old_state=old_state) + self.state.compute_event_context( + event, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in old_state + }, + ) ) prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) -- cgit 1.4.1