From b6baa46db078c3ef9e6c5751bccb8d2e1c5c5402 Mon Sep 17 00:00:00 2001 From: Shay Date: Wed, 12 Oct 2022 11:01:00 -0700 Subject: Fix a bug where the joined hosts for a given event were not being properly cached (#14125) --- synapse/handlers/message.py | 91 +++++++++++++++++++++++---------------------- 1 file changed, 47 insertions(+), 44 deletions(-) (limited to 'synapse/handlers/message.py') diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index da1acea275..4e55ebba0b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1390,7 +1390,7 @@ class EventCreationHandler: extra_users=extra_users, ), run_in_background( - self.cache_joined_hosts_for_event, event, context + self.cache_joined_hosts_for_events, events_and_context ).addErrback( log_failure, "cache_joined_hosts_for_event failed" ), @@ -1491,62 +1491,65 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise - async def cache_joined_hosts_for_event( - self, event: EventBase, context: EventContext + async def cache_joined_hosts_for_events( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Precalculate the joined hosts at the event, when using Redis, so that + """Precalculate the joined hosts at each of the given events, when using Redis, so that external federation senders don't have to recalculate it themselves. """ - if not self._external_cache.is_enabled(): - return - - # If external cache is enabled we should always have this. - assert self._external_cache_joined_hosts_updates is not None + for event, _ in events_and_context: + if not self._external_cache.is_enabled(): + return - # We actually store two mappings, event ID -> prev state group, - # state group -> joined hosts, which is much more space efficient - # than event ID -> joined hosts. - # - # Note: We have to cache event ID -> prev state group, as we don't - # store that in the DB. - # - # Note: We set the state group -> joined hosts cache if it hasn't been - # set for a while, so that the expiry time is reset. + # If external cache is enabled we should always have this. + assert self._external_cache_joined_hosts_updates is not None - state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() - ) + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We set the state group -> joined hosts cache if it hasn't been + # set for a while, so that the expiry time is reset. - if state_entry.state_group: - await self._external_cache.set( - "event_to_prev_state_group", - event.event_id, - state_entry.state_group, - expiry_ms=60 * 60 * 1000, + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() ) - if state_entry.state_group in self._external_cache_joined_hosts_updates: - return + if state_entry.state_group: + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) - state = await state_entry.get_state( - self._storage_controllers.state, StateFilter.all() - ) - with opentracing.start_active_span("get_joined_hosts"): - joined_hosts = await self.store.get_joined_hosts( - event.room_id, state, state_entry + if state_entry.state_group in self._external_cache_joined_hosts_updates: + return + + state = await state_entry.get_state( + self._storage_controllers.state, StateFilter.all() ) + with opentracing.start_active_span("get_joined_hosts"): + joined_hosts = await self.store.get_joined_hosts( + event.room_id, state, state_entry + ) - # Note that the expiry times must be larger than the expiry time in - # _external_cache_joined_hosts_updates. - await self._external_cache.set( - "get_joined_hosts", - str(state_entry.state_group), - list(joined_hosts), - expiry_ms=60 * 60 * 1000, - ) + # Note that the expiry times must be larger than the expiry time in + # _external_cache_joined_hosts_updates. + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) - self._external_cache_joined_hosts_updates[state_entry.state_group] = None + self._external_cache_joined_hosts_updates[ + state_entry.state_group + ] = None async def _validate_canonical_alias( self, -- cgit 1.5.1 From 847e2393f3198b88809c9b99de5c681efbf1c92e Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 18 Oct 2022 09:58:47 -0700 Subject: Prepatory work for adding power level event to batched events (#14214) --- changelog.d/14214.misc | 1 + synapse/event_auth.py | 19 ++++++++++++++++++- synapse/handlers/event_auth.py | 18 +++++++++++++----- synapse/handlers/federation.py | 12 +++++------- synapse/handlers/message.py | 10 +++++++++- synapse/handlers/room.py | 4 +--- 6 files changed, 47 insertions(+), 17 deletions(-) create mode 100644 changelog.d/14214.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14214.misc b/changelog.d/14214.misc new file mode 100644 index 0000000000..102928b575 --- /dev/null +++ b/changelog.d/14214.misc @@ -0,0 +1 @@ +When authenticating batched events, check for auth events in batch as well as DB. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c7d5ef92fc..bab31e33c5 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -15,7 +15,18 @@ import logging import typing -from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -134,6 +145,7 @@ def validate_event_for_room_version(event: "EventBase") -> None: async def check_state_independent_auth_rules( store: _EventSourceStore, event: "EventBase", + batched_auth_events: Optional[Mapping[str, "EventBase"]] = None, ) -> None: """Check that an event complies with auth rules that are independent of room state @@ -143,6 +155,8 @@ async def check_state_independent_auth_rules( Args: store: the datastore; used to fetch the auth events for validation event: the event being checked. + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event Raises: AuthError if the checks fail @@ -162,6 +176,9 @@ async def check_state_independent_auth_rules( redact_behaviour=EventRedactBehaviour.as_is, allow_rejected=True, ) + if batched_auth_events: + auth_events.update(batched_auth_events) + room_id = event.room_id auth_dict: MutableStateMap[str] = {} expected_auth_types = auth_types_for_event(event.room_version, event) diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index 8249ca1ed2..3bbad0271b 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, List, Optional, Union +from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( @@ -29,7 +29,6 @@ from synapse.event_auth import ( ) from synapse.events import EventBase from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext from synapse.types import StateMap, get_domain_from_id if TYPE_CHECKING: @@ -51,12 +50,21 @@ class EventAuthHandler: async def check_auth_rules_from_context( self, event: EventBase, - context: EventContext, + batched_auth_events: Optional[Mapping[str, EventBase]] = None, ) -> None: - """Check an event passes the auth rules at its own auth events""" - await check_state_independent_auth_rules(self._store, event) + """Check an event passes the auth rules at its own auth events + Args: + event: event to be authed + batched_auth_events: if the event being authed is part of a batch, any events + from the same batch that may be necessary to auth the current event + """ + await check_state_independent_auth_rules( + self._store, event, batched_auth_events + ) auth_event_ids = event.auth_event_ids() auth_events_by_id = await self._store.get_events(auth_event_ids) + if batched_auth_events: + auth_events_by_id.update(batched_auth_events) check_state_dependent_auth_rules(event, auth_events_by_id.values()) def compute_auth_events( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index ccc045d36f..275a37a575 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -942,7 +942,7 @@ class FederationHandler: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) return event async def on_invite_request( @@ -1123,7 +1123,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new leave %r because %s", event, e) raise e @@ -1182,7 +1182,7 @@ class FederationHandler: try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_knock_request` - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Failed to create new knock %r because %s", event, e) raise e @@ -1348,9 +1348,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context( - event, context - ) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying new third party invite %r because %s", event, e) raise e @@ -1400,7 +1398,7 @@ class FederationHandler: try: validate_event_for_room_version(event) - await self._event_auth_handler.check_auth_rules_from_context(event, context) + await self._event_auth_handler.check_auth_rules_from_context(event) except AuthError as e: logger.warning("Denying third party invite %r because %s", event, e) raise e diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4e55ebba0b..15b828dd74 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1360,8 +1360,16 @@ class EventCreationHandler: else: try: validate_event_for_room_version(event) + # If we are persisting a batch of events the event(s) needed to auth the + # current event may be part of the batch and will not be in the DB yet + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + batched_auth_events = {} + for event_id in event.auth_event_ids(): + auth_event = event_id_to_event.get(event_id) + if auth_event: + batched_auth_events[event_id] = auth_event await self._event_auth_handler.check_auth_rules_from_context( - event, context + event, batched_auth_events ) except AuthError as err: logger.warning("Denying new event %r because %s", event, err) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 4e1aacb408..638f54051a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -229,9 +229,7 @@ class RoomCreationHandler: }, ) validate_event_for_room_version(tombstone_event) - await self._event_auth_handler.check_auth_rules_from_context( - tombstone_event, tombstone_context - ) + await self._event_auth_handler.check_auth_rules_from_context(tombstone_event) # Upgrade the room # -- cgit 1.5.1 From b7a7ff6ee39da4981dcfdce61bf8ac4735e3d047 Mon Sep 17 00:00:00 2001 From: Shay Date: Fri, 21 Oct 2022 10:46:22 -0700 Subject: Add initial power level event to batch of bulk persisted events when creating a new room. (#14228) --- changelog.d/14228.misc | 1 + synapse/handlers/federation.py | 4 +- synapse/handlers/federation_event.py | 4 +- synapse/handlers/message.py | 14 ++---- synapse/handlers/room.py | 39 ++++----------- synapse/push/bulk_push_rule_evaluator.py | 74 ++++++++++++++++++++++++----- tests/push/test_bulk_push_rule_evaluator.py | 2 +- tests/replication/_base.py | 2 +- 8 files changed, 82 insertions(+), 58 deletions(-) create mode 100644 changelog.d/14228.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14228.misc b/changelog.d/14228.misc new file mode 100644 index 0000000000..14fe31a8bc --- /dev/null +++ b/changelog.d/14228.misc @@ -0,0 +1 @@ +Add initial power level event to batch of bulk persisted events when creating a new room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 275a37a575..4fbc79a6cb 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1017,7 +1017,9 @@ class FederationHandler: context = EventContext.for_outlier(self._storage_controllers) - await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] + ) try: await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 06e41b5cc0..7da6316a82 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2171,8 +2171,8 @@ class FederationEventHandler: min_depth, ) else: - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context + await self._bulk_push_rule_evaluator.action_for_events_by_user( + [(event, context)] ) try: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 15b828dd74..468900a07f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1433,17 +1433,9 @@ class EventCreationHandler: a room that has been un-partial stated. """ - for event, context in events_and_context: - # Skip push notification actions for historical messages - # because we don't want to notify people about old history back in time. - # The historical messages also do not have the proper `context.current_state_ids` - # and `state_groups` because they have `prev_events` that aren't persisted yet - # (historical messages persisted in reverse-chronological order). - if not event.internal_metadata.is_historical(): - with opentracing.start_active_span("calculate_push_actions"): - await self._bulk_push_rule_evaluator.action_for_event_by_user( - event, context - ) + await self._bulk_push_rule_evaluator.action_for_events_by_user( + events_and_context + ) try: # If we're a worker we need to hit out to the master. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 638f54051a..cc1e5c8f97 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1055,9 +1055,6 @@ class RoomCreationHandler: event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} depth = 1 - # the last event sent/persisted to the db - last_sent_event_id: Optional[str] = None - # the most recently created event prev_event: List[str] = [] # a map of event types, state keys -> event_ids. We collect these mappings this as events are @@ -1102,26 +1099,6 @@ class RoomCreationHandler: return new_event, new_context - async def send( - event: EventBase, - context: synapse.events.snapshot.EventContext, - creator: Requester, - ) -> int: - nonlocal last_sent_event_id - - ev = await self.event_creation_handler.handle_new_client_event( - requester=creator, - events_and_context=[(event, context)], - ratelimit=False, - ignore_shadow_ban=True, - ) - - last_sent_event_id = ev.event_id - - # we know it was persisted, so must have a stream ordering - assert ev.internal_metadata.stream_ordering - return ev.internal_metadata.stream_ordering - try: config = self._presets_dict[preset_config] except KeyError: @@ -1135,10 +1112,14 @@ class RoomCreationHandler: ) logger.debug("Sending %s in new room", EventTypes.Member) - await send(creation_event, creation_context, creator) + ev = await self.event_creation_handler.handle_new_client_event( + requester=creator, + events_and_context=[(creation_event, creation_context)], + ratelimit=False, + ignore_shadow_ban=True, + ) + last_sent_event_id = ev.event_id - # Room create event must exist at this point - assert last_sent_event_id is not None member_event_id, _ = await self.room_member_handler.update_membership( creator, creator.user, @@ -1157,6 +1138,7 @@ class RoomCreationHandler: depth += 1 state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id + events_to_send = [] # We treat the power levels override specially as this needs to be one # of the first events that get sent into a room. pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) @@ -1165,7 +1147,7 @@ class RoomCreationHandler: EventTypes.PowerLevels, pl_content, False ) current_state_group = power_context._state_group - await send(power_event, power_context, creator) + events_to_send.append((power_event, power_context)) else: power_level_content: JsonDict = { "users": {creator_id: 100}, @@ -1214,9 +1196,8 @@ class RoomCreationHandler: False, ) current_state_group = pl_context._state_group - await send(pl_event, pl_context, creator) + events_to_send.append((pl_event, pl_context)) - events_to_send = [] if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: room_alias_event, room_alias_context = await create_event( EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a75386f6a0..d7795a9080 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -165,8 +165,21 @@ class BulkPushRuleEvaluator: return rules_by_user async def _get_power_levels_and_sender_level( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], ) -> Tuple[dict, Optional[int]]: + """ + Given an event and an event context, get the power level event relevant to the event + and the power level of the sender of the event. + Args: + event: event to check + context: context of event to check + event_id_to_event: a mapping of event_id to event for a set of events being + batch persisted. This is needed as the sought-after power level event may + be in this batch rather than the DB + """ # There are no power levels and sender levels possible to get from outlier if event.internal_metadata.is_outlier(): return {}, None @@ -177,15 +190,26 @@ class BulkPushRuleEvaluator: ) pl_event_id = prev_state_ids.get(POWER_KEY) + # fastpath: if there's a power level event, that's all we need, and + # not having a power level event is an extreme edge case if pl_event_id: - # fastpath: if there's a power level event, that's all we need, and - # not having a power level event is an extreme edge case - auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} + # Get the power level event from the batch, or fall back to the database. + pl_event = event_id_to_event.get(pl_event_id) + if pl_event: + auth_events = {POWER_KEY: pl_event} + else: + auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)} else: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=False ) auth_events_dict = await self.store.get_events(auth_events_ids) + # Some needed auth events might be in the batch, combine them with those + # fetched from the database. + for auth_event_id in auth_events_ids: + auth_event = event_id_to_event.get(auth_event_id) + if auth_event: + auth_events_dict[auth_event_id] = auth_event auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -194,16 +218,38 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - @measure_func("action_for_event_by_user") - async def action_for_event_by_user( - self, event: EventBase, context: EventContext + async def action_for_events_by_user( + self, events_and_context: List[Tuple[EventBase, EventContext]] ) -> None: - """Given an event and context, evaluate the push rules, check if the message - should increment the unread count, and insert the results into the - event_push_actions_staging table. + """Given a list of events and their associated contexts, evaluate the push rules + for each event, check if the message should increment the unread count, and + insert the results into the event_push_actions_staging table. """ - if not event.internal_metadata.is_notifiable(): - # Push rules for events that aren't notifiable can't be processed by this + # For batched events the power level events may not have been persisted yet, + # so we pass in the batched events. Thus if the event cannot be found in the + # database we can check in the batch. + event_id_to_event = {e.event_id: e for e, _ in events_and_context} + for event, context in events_and_context: + await self._action_for_event_by_user(event, context, event_id_to_event) + + @measure_func("action_for_event_by_user") + async def _action_for_event_by_user( + self, + event: EventBase, + context: EventContext, + event_id_to_event: Mapping[str, EventBase], + ) -> None: + + if ( + not event.internal_metadata.is_notifiable() + or event.internal_metadata.is_historical() + ): + # Push rules for events that aren't notifiable can't be processed by this and + # we want to skip push notification actions for historical messages + # because we don't want to notify people about old history back in time. + # The historical messages also do not have the proper `context.current_state_ids` + # and `state_groups` because they have `prev_events` that aren't persisted yet + # (historical messages persisted in reverse-chronological order). return # Disable counting as unread unless the experimental configuration is @@ -223,7 +269,9 @@ class BulkPushRuleEvaluator: ( power_levels, sender_power_level, - ) = await self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level( + event, context, event_id_to_event + ) # Find the event's thread ID. relation = relation_from_event(event) diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 675d7df2ac..594e7937a8 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): bulk_evaluator = BulkPushRuleEvaluator(self.hs) # should not raise - self.get_success(bulk_evaluator.action_for_event_by_user(event, context)) + self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index ce53f808db..121f3d8d65 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) -- cgit 1.5.1 From 86c5a710d8b4212f8a8a668d7d4a79c0bb371508 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Thu, 3 Nov 2022 16:21:31 +0000 Subject: Implement MSC3912: Relation-based redactions (#14260) Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/14260.feature | 1 + synapse/api/constants.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/message.py | 47 ++++- synapse/handlers/relations.py | 56 +++++- synapse/rest/client/room.py | 57 ++++-- synapse/rest/client/versions.py | 2 + synapse/storage/databases/main/relations.py | 36 ++++ tests/rest/client/test_redactions.py | 273 +++++++++++++++++++++++++++- tests/rest/client/utils.py | 37 ++++ 10 files changed, 486 insertions(+), 28 deletions(-) create mode 100644 changelog.d/14260.feature (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14260.feature b/changelog.d/14260.feature new file mode 100644 index 0000000000..102dc7b3e0 --- /dev/null +++ b/changelog.d/14260.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 44c5ffc6a5..bc04a0755b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -125,6 +125,8 @@ class EventTypes: MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_MARKER: Final = "org.matrix.msc2716.marker" + Reaction: Final = "m.reaction" + class ToDeviceEventTypes: RoomKeyRequest: Final = "m.room_key_request" diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d9bdd66d55..d4b71d1673 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -128,3 +128,6 @@ class ExperimentalConfig(Config): self.msc3886_endpoint: Optional[str] = experimental.get( "msc3886_endpoint", None ) + + # MSC3912: Relation-based redactions. + self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 468900a07f..4cf593cfdc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -877,6 +877,36 @@ class EventCreationHandler: return prev_event return None + async def get_event_from_transaction( + self, + requester: Requester, + txn_id: str, + room_id: str, + ) -> Optional[EventBase]: + """For the given transaction ID and room ID, check if there is a matching event. + If so, fetch it and return it. + + Args: + requester: The requester making the request in the context of which we want + to fetch the event. + txn_id: The transaction ID. + room_id: The room ID. + + Returns: + An event if one could be found, None otherwise. + """ + if requester.access_token_id: + existing_event_id = await self.store.get_event_id_from_transaction_id( + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, + ) + if existing_event_id: + return await self.store.get_event(existing_event_id) + + return None + async def create_and_send_nonmember_event( self, requester: Requester, @@ -956,18 +986,17 @@ class EventCreationHandler: # extremities to pile up, which in turn leads to state resolution # taking longer. async with self.limiter.queue(event_dict["room_id"]): - if txn_id and requester.access_token_id: - existing_event_id = await self.store.get_event_id_from_transaction_id( - event_dict["room_id"], - requester.user.to_string(), - requester.access_token_id, - txn_id, + if txn_id: + event = await self.get_event_from_transaction( + requester, txn_id, event_dict["room_id"] ) - if existing_event_id: - event = await self.store.get_event(existing_event_id) + if event: # we know it was persisted, so must have a stream ordering assert event.internal_metadata.stream_ordering - return event, event.internal_metadata.stream_ordering + return ( + event, + event.internal_metadata.stream_ordering, + ) event, context = await self.create_event( requester, diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 0a0c6d938e..8e71dda970 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup import attr -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace @@ -75,6 +75,7 @@ class RelationsHandler: self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._event_creation_handler = hs.get_event_creation_handler() async def get_relations( self, @@ -205,6 +206,59 @@ class RelationsHandler: return related_events, next_token + async def redact_events_related_to( + self, + requester: Requester, + event_id: str, + initial_redaction_event: EventBase, + relation_types: List[str], + ) -> None: + """Redacts all events related to the given event ID with one of the given + relation types. + + This method is expected to be called when redacting the event referred to by + the given event ID. + + If an event cannot be redacted (e.g. because of insufficient permissions), log + the error and try to redact the next one. + + Args: + requester: The requester to redact events on behalf of. + event_id: The event IDs to look and redact relations of. + initial_redaction_event: The redaction for the event referred to by + event_id. + relation_types: The types of relations to look for. + + Raises: + ShadowBanError if the requester is shadow-banned + """ + related_event_ids = ( + await self._main_store.get_all_relations_for_event_with_types( + event_id, relation_types + ) + ) + + for related_event_id in related_event_ids: + try: + await self._event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": initial_redaction_event.content, + "room_id": initial_redaction_event.room_id, + "sender": requester.user.to_string(), + "redacts": related_event_id, + }, + ratelimit=False, + ) + except SynapseError as e: + logger.warning( + "Failed to redact event %s (related to event %s): %s", + related_event_id, + event_id, + e.msg, + ) + async def get_annotations_for_event( self, event_id: str, diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 01e5079963..91cb791139 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -52,6 +52,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.state import StateFilter @@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() + self._relation_handler = hs.get_relations_handler() + self._msc3912_enabled = hs.config.experimental.msc3912_enabled def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P[^/]*)/redact/(?P[^/]*)" @@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet): content = parse_json_object_from_request(request) try: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + with_relations = None + if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content: + with_relations = content["org.matrix.msc3912.with_relations"] + del content["org.matrix.msc3912.with_relations"] + + # Check if there's an existing event for this transaction now (even though + # create_and_send_nonmember_event also does it) because, if there's one, + # then we want to skip the call to redact_events_related_to. + event = None + if txn_id: + event = await self.event_creation_handler.get_event_from_transaction( + requester, txn_id, room_id + ) + + if event is None: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + + if with_relations: + run_as_background_process( + "redact_related_events", + self._relation_handler.redact_events_related_to, + requester=requester, + event_id=event_id, + initial_redaction_event=event, + relation_types=with_relations, + ) + event_id = event.event_id except ShadowBanError: event_id = "$" + random_string(43) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 9b1b72c68a..180a11ef88 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet): # Adds support for simple HTTP rendezvous as per MSC3886 "org.matrix.msc3886": self.config.experimental.msc3886_endpoint is not None, + # Adds support for relation-based redactions as per MSC3912. + "org.matrix.msc3912": self.config.experimental.msc3912_enabled, }, }, ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c022510e76..ca431002c8 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore): "get_recent_references_for_event", _get_recent_references_for_event_txn ) + async def get_all_relations_for_event_with_types( + self, + event_id: str, + relation_types: List[str], + ) -> List[str]: + """Get the event IDs of all events that have a relation to the given event with + one of the given relation types. + + Args: + event_id: The event for which to look for related events. + relation_types: The types of relations to look for. + + Returns: + A list of the IDs of the events that relate to the given event with one of + the given relation types. + """ + + def get_all_relation_ids_for_event_with_types_txn( + txn: LoggingTransaction, + ) -> List[str]: + rows = self.db_pool.simple_select_many_txn( + txn=txn, + table="event_relations", + column="relation_type", + iterable=relation_types, + keyvalues={"relates_to_id": event_id}, + retcols=["event_id"], + ) + + return [row["event_id"] for row in rows] + + return await self.db_pool.runInteraction( + desc="get_all_relation_ids_for_event_with_types", + func=get_all_relation_ids_for_event_with_types_txn, + ) + async def event_includes_relation(self, event_id: str) -> bool: """Check if the given event relates to another event. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index be4c67d68e..5dfe44defb 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from twisted.test.proto_helpers import MemoryReactor +from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RedactionsTestCase(HomeserverTestCase): @@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase): ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = 200, + with_relations: Optional[List[str]] = None, ) -> JsonDict: """Helper function to send a redaction event. @@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase): """ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) - channel = self.make_request("POST", path, content={}, access_token=access_token) + request_content = {} + if with_relations: + request_content["org.matrix.msc3912.with_relations"] = with_relations + + channel = self.make_request( + "POST", path, request_content, access_token=access_token + ) self.assertEqual(channel.code, expect_code) return channel.json_body @@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase): # These should all succeed, even though this would be denied by # the standard message ratelimiter self._redact_event(self.mod_access_token, self.room_id, msg_id) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations(self) -> None: + """Tests that we can redact the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete events that + # relate to it with m.replace. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=[ + RelationTypes.REPLACE, + RelationTypes.THREAD, + ], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction did not get redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertNotIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_no_perms(self) -> None: + """Tests that, when redacting a message along with its relations, if not all + the related messages can be redacted because of insufficient permissions, the + server still redacts all the ones that can be. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.other_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message, this one from the moderator. We do this for the + # first message with the m.thread relation (and not the last one) to ensure + # that, when the server fails to redact it, it doesn't stop there, and it + # instead goes on to redact the other one. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + first_threaded_event_id = res["event_id"] + + # Send a second threaded message, this time from the user who'll perform the + # redaction. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 2", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.other_access_token, + ) + second_threaded_event_id = res["event_id"] + + # Redact the thread's root, and request that all threaded messages are also + # redacted. Send that request from the non-mod user, so that the first threaded + # event cannot be redacted. + self._redact_event( + self.other_access_token, + self.room_id, + root_event_id, + with_relations=[RelationTypes.THREAD], + ) + + # Check that the thread root got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the last message in the thread got redacted, despite failing to + # redact the one before it. + event_dict = self.helper.get_event( + self.room_id, second_threaded_event_id, self.other_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the message that was sent into the tread by the mod user is not + # redacted. + event_dict = self.helper.get_event( + self.room_id, first_threaded_event_id, self.other_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("message 1", event_dict["content"]["body"]) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_relations_txn_id_reuse(self) -> None: + """Tests that redacting a message using a transaction ID, then reusing the same + transaction ID but providing an additional list of relations to redact, is + effectively a no-op. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "root", + }, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send a first threaded message. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "I'm in a thread!", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Send a first redaction request which redacts only the root event. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Send a second redaction request which redacts the root event as well as + # threaded messages. + channel = self.make_request( + method="PUT", + path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo", + content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]}, + access_token=self.mod_access_token, + ) + self.assertEqual(channel.code, 200) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict) + + # Check that the threaded message didn't get redacted (since that wasn't part of + # the original redaction). + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("body", event_dict["content"], event_dict) + self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 706399fae5..8d6f2b6ff9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -410,6 +410,43 @@ class RestHelper: return channel.json_body + def get_event( + self, + room_id: str, + event_id: str, + tok: Optional[str] = None, + expect_code: int = HTTPStatus.OK, + ) -> JsonDict: + """Request a specific event from the server. + + Args: + room_id: the room in which the event was sent. + event_id: the event's ID. + tok: the token to request the event with. + expect_code: the expected HTTP status for the response. + + Returns: + The event as a dict. + """ + path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}" + if tok: + path = path + f"?access_token={tok}" + + channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + path, + ) + + assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( + expect_code, + channel.code, + channel.result["body"], + ) + + return channel.json_body + def _read_write_state( self, room_id: str, -- cgit 1.5.1 From 72f3e381375ba10d576a23025ca312397114de6b Mon Sep 17 00:00:00 2001 From: Shay Date: Mon, 28 Nov 2022 19:18:12 -0800 Subject: Fix possible variable shadow in `create_new_client_event` (#14575) --- changelog.d/14575.misc | 1 + synapse/handlers/message.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 changelog.d/14575.misc (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14575.misc b/changelog.d/14575.misc new file mode 100644 index 0000000000..f6fa54eaa2 --- /dev/null +++ b/changelog.d/14575.misc @@ -0,0 +1 @@ +Fix a possible variable shadow in `create_new_client_event`. \ No newline at end of file diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4cf593cfdc..5cbe89f4fd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1135,11 +1135,13 @@ class EventCreationHandler: ) state_events = await self.store.get_events_as_list(state_event_ids) # Create a StateMap[str] - state_map = {(e.type, e.state_key): e.event_id for e in state_events} + current_state_ids = { + (e.type, e.state_key): e.event_id for e in state_events + } # Actually strip down and only use the necessary auth events auth_event_ids = self._event_auth_handler.compute_auth_events( event=temp_event, - current_state_ids=state_map, + current_state_ids=current_state_ids, for_verification=False, ) -- cgit 1.5.1 From b5b5f6608462a988b05502a3b70b6a57ca3846d2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:19:30 +0000 Subject: Move `StateFilter` to `synapse.types` (#14668) * Move `StateFilter` to `synapse.types` * Changelog --- changelog.d/14668.misc | 1 + synapse/events/builder.py | 2 +- synapse/events/snapshot.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/module_api/__init__.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/mailer.py | 2 +- synapse/rest/admin/rooms.py | 2 +- synapse/rest/client/room.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/controllers/persist_events.py | 2 +- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/state.py | 2 +- synapse/storage/databases/state/bg_updates.py | 2 +- synapse/storage/databases/state/store.py | 2 +- synapse/storage/state.py | 567 ---------------- synapse/types.py | 928 -------------------------- synapse/types/__init__.py | 928 ++++++++++++++++++++++++++ synapse/types/state.py | 567 ++++++++++++++++ synapse/visibility.py | 2 +- tests/storage/test_state.py | 2 +- 29 files changed, 1520 insertions(+), 1519 deletions(-) create mode 100644 changelog.d/14668.misc delete mode 100644 synapse/storage/state.py delete mode 100644 synapse/types.py create mode 100644 synapse/types/__init__.py create mode 100644 synapse/types/state.py (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14668.misc b/changelog.d/14668.misc new file mode 100644 index 0000000000..5269d8a97d --- /dev/null +++ b/changelog.d/14668.misc @@ -0,0 +1 @@ +Move `StateFilter` to `synapse.types`. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index d62906043f..94dd1298e1 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -28,8 +28,8 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import EventID, JsonDict +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.stringutils import random_string diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 1c0e96bec7..6eaef8b57a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,7 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore - from synapse.storage.state import StateFilter + from synapse.types.state import StateFilter @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3398fcaf7d..b2784d7333 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -70,8 +70,8 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f7223b03c3..d2facdab60 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -75,7 +75,6 @@ from synapse.replication.http.federation import ( from synapse.state import StateResolutionStore from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, @@ -83,6 +82,7 @@ from synapse.types import ( UserID, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5cbe89f4fd..d6e90ef259 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,7 +59,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( MutableStateMap, PersistedEventPosition, @@ -70,6 +69,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index c572508a02..8c8ff18a1a 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -27,9 +27,9 @@ from synapse.handlers.room import ShutdownRoomResponse from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6307fa9c5d..c611efb760 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -46,8 +46,8 @@ from synapse.replication.http.register import ( ReplicationRegisterServlet, ) from synapse.spam_checker_api import RegistrationBehaviour -from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester +from synapse.types.state import StateFilter if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6dcfd86fdf..f81241c2b3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -62,7 +62,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( JsonDict, @@ -77,6 +76,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import stringutils from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6ad2b38b8f..0c39e852a1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -34,7 +34,6 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM -from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, Requester, @@ -45,6 +44,7 @@ from synapse.types import ( create_requester, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index bcab98c6d5..33115ce488 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,8 +23,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.storage.state import StateFilter from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client if TYPE_CHECKING: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dace9b606f..7d6a653747 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -49,7 +49,6 @@ from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary -from synapse.storage.state import StateFilter from synapse.types import ( DeviceListUpdates, JsonDict, @@ -61,6 +60,7 @@ from synapse.types import ( StreamToken, UserID, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 96a661177a..0092a03c59 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,7 +111,6 @@ from synapse.storage.background_updates import ( ) from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.types import ( DomainSpecificString, JsonDict, @@ -124,6 +123,7 @@ from synapse.types import ( UserProfile, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches.descriptors import CachedFunction, cached diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9ed35d8461..36e5b327ef 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -35,8 +35,8 @@ from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership -from synapse.storage.state import StateFilter from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index c2575ba3d9..93b255ced5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -37,8 +37,8 @@ from synapse.push.push_types import ( TemplateVars, ) from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 747e6fda83..e957aa28ca 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,9 +34,9 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, RoomID, UserID, create_requester +from synapse.types.state import StateFilter from synapse.util import json_decoder if TYPE_CHECKING: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 514eb6afc8..790614d721 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -55,9 +55,9 @@ from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types.state import StateFilter from synapse.util import json_decoder from synapse.util.cancellation import cancellable from synapse.util.stringutils import parse_and_validate_server_name, random_string diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 833ffec3de..ee5469d5a8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -44,8 +44,8 @@ from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import StateMap +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 33ffef521b..f1d2c71c91 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -58,13 +58,13 @@ from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 2b31ce54bb..26d79c6e62 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -31,12 +31,12 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging.opentracing import tag_args, trace from synapse.storage.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, PartialStateEventsTracker, ) from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.cancellation import cancellable if TYPE_CHECKING: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index af7bebee80..c801a93b5b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -33,8 +33,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 4a4ad0f492..d743282f13 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -22,8 +22,8 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.storage.state import StateFilter from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string if TYPE_CHECKING: diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f8cfcaca83..1a7232b276 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -25,10 +25,10 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.types.state import StateFilter from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.cancellation import cancellable diff --git a/synapse/storage/state.py b/synapse/storage/state.py deleted file mode 100644 index 0004d955b4..0000000000 --- a/synapse/storage/state.py +++ /dev/null @@ -1,567 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import ( - TYPE_CHECKING, - Callable, - Collection, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - TypeVar, -) - -import attr -from frozendict import frozendict - -from synapse.api.constants import EventTypes -from synapse.types import MutableStateMap, StateKey, StateMap - -if TYPE_CHECKING: - from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - - -logger = logging.getLogger(__name__) - -# Used for generic functions below -T = TypeVar("T") - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StateFilter: - """A filter used when querying for state. - - Attributes: - types: Map from type to set of state keys (or None). This specifies - which state_keys for the given type to fetch from the DB. If None - then all events with that type are fetched. If the set is empty - then no events with that type are fetched. - include_others: Whether to fetch events with types that do not - appear in `types`. - """ - - types: "frozendict[str, Optional[FrozenSet[str]]]" - include_others: bool = False - - def __attrs_post_init__(self) -> None: - # If `include_others` is set we canonicalise the filter by removing - # wildcards from the types dictionary - if self.include_others: - # this is needed to work around the fact that StateFilter is frozen - object.__setattr__( - self, - "types", - frozendict({k: v for k, v in self.types.items() if v is not None}), - ) - - @staticmethod - def all() -> "StateFilter": - """Returns a filter that fetches everything. - - Returns: - The state filter. - """ - return _ALL_STATE_FILTER - - @staticmethod - def none() -> "StateFilter": - """Returns a filter that fetches nothing. - - Returns: - The new state filter. - """ - return _NONE_STATE_FILTER - - @staticmethod - def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": - """Creates a filter that only fetches the given types - - Args: - types: A list of type and state keys to fetch. A state_key of None - fetches everything for that type - - Returns: - The new state filter. - """ - type_dict: Dict[str, Optional[Set[str]]] = {} - for typ, s in types: - if typ in type_dict: - if type_dict[typ] is None: - continue - - if s is None: - type_dict[typ] = None - continue - - type_dict.setdefault(typ, set()).add(s) # type: ignore - - return StateFilter( - types=frozendict( - (k, frozenset(v) if v is not None else None) - for k, v in type_dict.items() - ) - ) - - @staticmethod - def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": - """Creates a filter that returns all non-member events, plus the member - events for the given users - - Args: - members: Set of user IDs - - Returns: - The new state filter - """ - return StateFilter( - types=frozendict({EventTypes.Member: frozenset(members)}), - include_others=True, - ) - - @staticmethod - def freeze( - types: Mapping[str, Optional[Collection[str]]], include_others: bool - ) -> "StateFilter": - """ - Returns a (frozen) StateFilter with the same contents as the parameters - specified here, which can be made of mutable types. - """ - types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} - for state_types, state_keys in types.items(): - if state_keys is not None: - types_with_frozen_values[state_types] = frozenset(state_keys) - else: - types_with_frozen_values[state_types] = None - - return StateFilter( - frozendict(types_with_frozen_values), include_others=include_others - ) - - def return_expanded(self) -> "StateFilter": - """Creates a new StateFilter where type wild cards have been removed - (except for memberships). The returned filter is a superset of the - current one, i.e. anything that passes the current filter will pass - the returned filter. - - This helps the caching as the DictionaryCache knows if it has *all* the - state, but does not know if it has all of the keys of a particular type, - which makes wildcard lookups expensive unless we have a complete cache. - Hence, if we are doing a wildcard lookup, populate the cache fully so - that we can do an efficient lookup next time. - - Note that since we have two caches, one for membership events and one for - other events, we can be a bit more clever than simply returning - `StateFilter.all()` if `has_wildcards()` is True. - - We return a StateFilter where: - 1. the list of membership events to return is the same - 2. if there is a wildcard that matches non-member events we - return all non-member events - - Returns: - The new state filter. - """ - - if self.is_full(): - # If we're going to return everything then there's nothing to do - return self - - if not self.has_wildcards(): - # If there are no wild cards, there's nothing to do - return self - - if EventTypes.Member in self.types: - get_all_members = self.types[EventTypes.Member] is None - else: - get_all_members = self.include_others - - has_non_member_wildcard = self.include_others or any( - state_keys is None - for t, state_keys in self.types.items() - if t != EventTypes.Member - ) - - if not has_non_member_wildcard: - # If there are no non-member wild cards we can just return ourselves - return self - - if get_all_members: - # We want to return everything. - return StateFilter.all() - elif EventTypes.Member in self.types: - # We want to return all non-members, but only particular - # memberships - return StateFilter( - types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), - include_others=True, - ) - else: - # We want to return all non-members - return _ALL_NON_MEMBER_STATE_FILTER - - def make_sql_filter_clause(self) -> Tuple[str, List[str]]: - """Converts the filter to an SQL clause. - - For example: - - f = StateFilter.from_types([("m.room.create", "")]) - clause, args = f.make_sql_filter_clause() - clause == "(type = ? AND state_key = ?)" - args == ['m.room.create', ''] - - - Returns: - The SQL string (may be empty) and arguments. An empty SQL string is - returned when the filter matches everything (i.e. is "full"). - """ - - where_clause = "" - where_args: List[str] = [] - - if self.is_full(): - return where_clause, where_args - - if not self.include_others and not self.types: - # i.e. this is an empty filter, so we need to return a clause that - # will match nothing - return "1 = 2", [] - - # First we build up a lost of clauses for each type/state_key combo - clauses = [] - for etype, state_keys in self.types.items(): - if state_keys is None: - clauses.append("(type = ?)") - where_args.append(etype) - continue - - for state_key in state_keys: - clauses.append("(type = ? AND state_key = ?)") - where_args.extend((etype, state_key)) - - # This will match anything that appears in `self.types` - where_clause = " OR ".join(clauses) - - # If we want to include stuff that's not in the types dict then we add - # a `OR type NOT IN (...)` clause to the end. - if self.include_others: - if where_clause: - where_clause += " OR " - - where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) - where_args.extend(self.types) - - return where_clause, where_args - - def max_entries_returned(self) -> Optional[int]: - """Returns the maximum number of entries this filter will return if - known, otherwise returns None. - - For example a simple state filter asking for `("m.room.create", "")` - will return 1, whereas the default state filter will return None. - - This is used to bail out early if the right number of entries have been - fetched. - """ - if self.has_wildcards(): - return None - - return len(self.concrete_types()) - - def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: - """Returns the state filtered with by this StateFilter. - - Args: - state: The state map to filter - - Returns: - The filtered state map. - This is a copy, so it's safe to mutate. - """ - if self.is_full(): - return dict(state_dict) - - filtered_state = {} - for k, v in state_dict.items(): - typ, state_key = k - if typ in self.types: - state_keys = self.types[typ] - if state_keys is None or state_key in state_keys: - filtered_state[k] = v - elif self.include_others: - filtered_state[k] = v - - return filtered_state - - def is_full(self) -> bool: - """Whether this filter fetches everything or not - - Returns: - True if the filter fetches everything. - """ - return self.include_others and not self.types - - def has_wildcards(self) -> bool: - """Whether the filter includes wildcards or is attempting to fetch - specific state. - - Returns: - True if the filter includes wildcards. - """ - - return self.include_others or any( - state_keys is None for state_keys in self.types.values() - ) - - def concrete_types(self) -> List[Tuple[str, str]]: - """Returns a list of concrete type/state_keys (i.e. not None) that - will be fetched. This will be a complete list if `has_wildcards` - returns False, but otherwise will be a subset (or even empty). - - Returns: - A list of type/state_keys tuples. - """ - return [ - (t, s) - for t, state_keys in self.types.items() - if state_keys is not None - for s in state_keys - ] - - def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: - """Return the filter split into two: one which assumes it's exclusively - matching against member state, and one which assumes it's matching - against non member state. - - This is useful due to the returned filters giving correct results for - `is_full()`, `has_wildcards()`, etc, when operating against maps that - either exclusively contain member events or only contain non-member - events. (Which is the case when dealing with the member vs non-member - state caches). - - Returns: - The member and non member filters - """ - - if EventTypes.Member in self.types: - state_keys = self.types[EventTypes.Member] - if state_keys is None: - member_filter = StateFilter.all() - else: - member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) - elif self.include_others: - member_filter = StateFilter.all() - else: - member_filter = StateFilter.none() - - non_member_filter = StateFilter( - types=frozendict( - {k: v for k, v in self.types.items() if k != EventTypes.Member} - ), - include_others=self.include_others, - ) - - return member_filter, non_member_filter - - def _decompose_into_four_parts( - self, - ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: - """ - Decomposes this state filter into 4 constituent parts, which can be - thought of as this: - all? - minus_wildcards + plus_wildcards + plus_state_keys - - where - * all represents ALL state - * minus_wildcards represents entire state types to remove - * plus_wildcards represents entire state types to add - * plus_state_keys represents individual state keys to add - - See `recompose_from_four_parts` for the other direction of this - correspondence. - """ - is_all = self.include_others - excluded_types: Set[str] = {t for t in self.types if is_all} - wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} - concrete_keys: Set[StateKey] = set(self.concrete_types()) - - return (is_all, excluded_types), (wildcard_types, concrete_keys) - - @staticmethod - def _recompose_from_four_parts( - all_part: bool, - minus_wildcards: Set[str], - plus_wildcards: Set[str], - plus_state_keys: Set[StateKey], - ) -> "StateFilter": - """ - Recomposes a state filter from 4 parts. - - See `decompose_into_four_parts` (the other direction of this - correspondence) for descriptions on each of the parts. - """ - - # {state type -> set of state keys OR None for wildcard} - # (The same structure as that of a StateFilter.) - new_types: Dict[str, Optional[Set[str]]] = {} - - # if we start with all, insert the excluded statetypes as empty sets - # to prevent them from being included - if all_part: - new_types.update({state_type: set() for state_type in minus_wildcards}) - - # insert the plus wildcards - new_types.update({state_type: None for state_type in plus_wildcards}) - - # insert the specific state keys - for state_type, state_key in plus_state_keys: - if state_type in new_types: - entry = new_types[state_type] - if entry is not None: - entry.add(state_key) - elif not all_part: - # don't insert if the entire type is already included by - # include_others as this would actually shrink the state allowed - # by this filter. - new_types[state_type] = {state_key} - - return StateFilter.freeze(new_types, include_others=all_part) - - def approx_difference(self, other: "StateFilter") -> "StateFilter": - """ - Returns a state filter which represents `self - other`. - - This is useful for determining what state remains to be pulled out of the - database if we want the state included by `self` but already have the state - included by `other`. - - The returned state filter - - MUST include all state events that are included by this filter (`self`) - unless they are included by `other`; - - MUST NOT include state events not included by this filter (`self`); and - - MAY be an over-approximation: the returned state filter - MAY additionally include some state events from `other`. - - This implementation attempts to return the narrowest such state filter. - In the case that `self` contains wildcards for state types where - `other` contains specific state keys, an approximation must be made: - the returned state filter keeps the wildcard, as state filters are not - able to express 'all state keys except some given examples'. - e.g. - StateFilter(m.room.member -> None (wildcard)) - minus - StateFilter(m.room.member -> {'@wombat:example.org'}) - is approximated as - StateFilter(m.room.member -> None (wildcard)) - """ - - # We first transform self and other into an alternative representation: - # - whether or not they include all events to begin with ('all') - # - if so, which event types are excluded? ('excludes') - # - which entire event types to include ('wildcards') - # - which concrete state keys to include ('concrete state keys') - (self_all, self_excludes), ( - self_wildcards, - self_concrete_keys, - ) = self._decompose_into_four_parts() - (other_all, other_excludes), ( - other_wildcards, - other_concrete_keys, - ) = other._decompose_into_four_parts() - - # Start with an estimate of the difference based on self - new_all = self_all - # Wildcards from the other can be added to the exclusion filter - new_excludes = self_excludes | other_wildcards - # We remove wildcards that appeared as wildcards in the other - new_wildcards = self_wildcards - other_wildcards - # We filter out the concrete state keys that appear in the other - # as wildcards or concrete state keys. - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in self_concrete_keys - if state_type not in other_wildcards - } - other_concrete_keys - - if other_all: - if self_all: - # If self starts with all, then we add as wildcards any - # types which appear in the other's exclusion filter (but - # aren't in the self exclusion filter). This is as the other - # filter will return everything BUT the types in its exclusion, so - # we need to add those excluded types that also match the self - # filter as wildcard types in the new filter. - new_wildcards |= other_excludes.difference(self_excludes) - - # If other is an `include_others` then the difference isn't. - new_all = False - # (We have no need for excludes when we don't start with all, as there - # is nothing to exclude.) - new_excludes = set() - - # We also filter out all state types that aren't in the exclusion - # list of the other. - new_wildcards &= other_excludes - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in new_concrete_keys - if state_type in other_excludes - } - - # Transform our newly-constructed state filter from the alternative - # representation back into the normal StateFilter representation. - return StateFilter._recompose_from_four_parts( - new_all, new_excludes, new_wildcards, new_concrete_keys - ) - - def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: - """Check if we need to wait for full state to complete to calculate this state - - If we have a state filter which is completely satisfied even with partial - state, then we don't need to await_full_state before we can return it. - - Args: - is_mine_id: a callable which confirms if a given state_key matches a mxid - of a local user - """ - # if we haven't requested membership events, then it depends on the value of - # 'include_others' - if EventTypes.Member not in self.types: - return self.include_others - - # if we're looking for *all* membership events, then we have to wait - member_state_keys = self.types[EventTypes.Member] - if member_state_keys is None: - return True - - # otherwise, consider whose membership we are looking for. If it's entirely - # local users, then we don't need to wait. - for state_key in member_state_keys: - if not is_mine_id(state_key): - # remote user - return True - - # local users only - return False - - -_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) -_ALL_NON_MEMBER_STATE_FILTER = StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True -) -_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/types.py b/synapse/types.py deleted file mode 100644 index f2d436ddc3..0000000000 --- a/synapse/types.py +++ /dev/null @@ -1,928 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import abc -import re -import string -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - List, - Mapping, - Match, - MutableMapping, - NoReturn, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) - -import attr -from frozendict import frozendict -from signedjson.key import decode_verify_key_bytes -from signedjson.types import VerifyKey -from typing_extensions import Final, TypedDict -from unpaddedbase64 import decode_base64 -from zope.interface import Interface - -from twisted.internet.defer import CancelledError -from twisted.internet.interfaces import ( - IReactorCore, - IReactorPluggableNameResolver, - IReactorSSL, - IReactorTCP, - IReactorThreads, - IReactorTime, -) - -from synapse.api.errors import Codes, SynapseError -from synapse.util.cancellation import cancellable -from synapse.util.stringutils import parse_and_validate_server_name - -if TYPE_CHECKING: - from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore, PurgeEventsStore - from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore - -# Define a state map type from type/state_key to T (usually an event ID or -# event) -T = TypeVar("T") -StateKey = Tuple[str, str] -StateMap = Mapping[StateKey, T] -MutableStateMap = MutableMapping[StateKey, T] - -# JSON types. These could be made stronger, but will do for now. -# A JSON-serialisable dict. -JsonDict = Dict[str, Any] -# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. -# Useful when you have a TypedDict which isn't going to be mutated and you don't want -# to cast to JsonDict everywhere. -JsonMapping = Mapping[str, Any] -# A JSON-serialisable object. -JsonSerializable = object - - -# Note that this seems to require inheriting *directly* from Interface in order -# for mypy-zope to realize it is an interface. -class ISynapseReactor( - IReactorTCP, - IReactorSSL, - IReactorPluggableNameResolver, - IReactorTime, - IReactorCore, - IReactorThreads, - Interface, -): - """The interfaces necessary for Synapse to function.""" - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class Requester: - """ - Represents the user making a request - - Attributes: - user: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request has been shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - """ - - user: "UserID" - access_token_id: Optional[int] - is_guest: bool - shadow_banned: bool - device_id: Optional[str] - app_service: Optional["ApplicationService"] - authenticated_entity: str - - def serialize(self) -> Dict[str, Any]: - """Converts self to a type that can be serialized as JSON, and then - deserialized by `deserialize` - - Returns: - dict - """ - return { - "user_id": self.user.to_string(), - "access_token_id": self.access_token_id, - "is_guest": self.is_guest, - "shadow_banned": self.shadow_banned, - "device_id": self.device_id, - "app_server_id": self.app_service.id if self.app_service else None, - "authenticated_entity": self.authenticated_entity, - } - - @staticmethod - def deserialize( - store: "ApplicationServiceWorkerStore", input: Dict[str, Any] - ) -> "Requester": - """Converts a dict that was produced by `serialize` back into a - Requester. - - Args: - store: Used to convert AS ID to AS object - input: A dict produced by `serialize` - - Returns: - Requester - """ - appservice = None - if input["app_server_id"]: - appservice = store.get_app_service_by_id(input["app_server_id"]) - - return Requester( - user=UserID.from_string(input["user_id"]), - access_token_id=input["access_token_id"], - is_guest=input["is_guest"], - shadow_banned=input["shadow_banned"], - device_id=input["device_id"], - app_service=appservice, - authenticated_entity=input["authenticated_entity"], - ) - - -def create_requester( - user_id: Union[str, "UserID"], - access_token_id: Optional[int] = None, - is_guest: bool = False, - shadow_banned: bool = False, - device_id: Optional[str] = None, - app_service: Optional["ApplicationService"] = None, - authenticated_entity: Optional[str] = None, -) -> Requester: - """ - Create a new ``Requester`` object - - Args: - user_id: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request is shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - - Returns: - Requester - """ - if not isinstance(user_id, UserID): - user_id = UserID.from_string(user_id) - - if authenticated_entity is None: - authenticated_entity = user_id.to_string() - - return Requester( - user_id, - access_token_id, - is_guest, - shadow_banned, - device_id, - app_service, - authenticated_entity, - ) - - -def get_domain_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[idx + 1 :] - - -def get_localpart_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[1:idx] - - -DS = TypeVar("DS", bound="DomainSpecificString") - - -@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) -class DomainSpecificString(metaclass=abc.ABCMeta): - """Common base class among ID/name strings that have a local part and a - domain name, prefixed with a sigil. - - Has the fields: - - 'localpart' : The local part of the name (without the leading sigil) - 'domain' : The domain part of the name - """ - - SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore - - localpart: str - domain: str - - # Because this is a frozen class, it is deeply immutable. - def __copy__(self: DS) -> DS: - return self - - def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: - return self - - @classmethod - def from_string(cls: Type[DS], s: str) -> DS: - """Parse the string given by 's' into a structure object.""" - if len(s) < 1 or s[0:1] != cls.SIGIL: - raise SynapseError( - 400, - "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - parts = s[1:].split(":", 1) - if len(parts) != 2: - raise SynapseError( - 400, - "Expected %s of the form '%slocalname:domain'" - % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - domain = parts[1] - # This code will need changing if we want to support multiple domain - # names on one HS - return cls(localpart=parts[0], domain=domain) - - def to_string(self) -> str: - """Return a string encoding the fields of the structure object.""" - return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) - - @classmethod - def is_valid(cls: Type[DS], s: str) -> bool: - """Parses the input string and attempts to ensure it is valid.""" - # TODO: this does not reject an empty localpart or an overly-long string. - # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar - try: - obj = cls.from_string(s) - # Apply additional validation to the domain. This is only done - # during is_valid (and not part of from_string) since it is - # possible for invalid data to exist in room-state, etc. - parse_and_validate_server_name(obj.domain) - return True - except Exception: - return False - - __repr__ = to_string - - -@attr.s(slots=True, frozen=True, repr=False) -class UserID(DomainSpecificString): - """Structure representing a user ID.""" - - SIGIL = "@" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomAlias(DomainSpecificString): - """Structure representing a room name.""" - - SIGIL = "#" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomID(DomainSpecificString): - """Structure representing a room id.""" - - SIGIL = "!" - - -@attr.s(slots=True, frozen=True, repr=False) -class EventID(DomainSpecificString): - """Structure representing an event id.""" - - SIGIL = "$" - - -mxid_localpart_allowed_characters = set( - "_-./=" + string.ascii_lowercase + string.digits -) - - -def contains_invalid_mxid_characters(localpart: str) -> bool: - """Check for characters not allowed in an mxid or groupid localpart - - Args: - localpart: the localpart to be checked - - Returns: - True if there are any naughty characters - """ - return any(c not in mxid_localpart_allowed_characters for c in localpart) - - -UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") - -# the following is a pattern which matches '=', and bytes which are not allowed in a mxid -# localpart. -# -# It works by: -# * building a string containing the allowed characters (excluding '=') -# * escaping every special character with a backslash (to stop '-' being interpreted as a -# range operator) -# * wrapping it in a '[^...]' regex -# * converting the whole lot to a 'bytes' sequence, so that we can use it to match -# bytes rather than strings -# -NON_MXID_CHARACTER_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( - "ascii" - ) -) - - -def map_username_to_mxid_localpart( - username: Union[str, bytes], case_sensitive: bool = False -) -> str: - """Map a username onto a string suitable for a MXID - - This follows the algorithm laid out at - https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. - - Args: - username: username to be mapped - case_sensitive: true if TEST and test should be mapped - onto different mxids - - Returns: - string suitable for a mxid localpart - """ - if not isinstance(username, bytes): - username = username.encode("utf-8") - - # first we sort out upper-case characters - if case_sensitive: - - def f1(m: Match[bytes]) -> bytes: - return b"_" + m.group().lower() - - username = UPPER_CASE_PATTERN.sub(f1, username) - else: - username = username.lower() - - # then we sort out non-ascii characters by converting to the hex equivalent. - def f2(m: Match[bytes]) -> bytes: - return b"=%02x" % (m.group()[0],) - - username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) - - # we also do the =-escaping to mxids starting with an underscore. - username = re.sub(b"^_", b"=5f", username) - - # we should now only have ascii bytes left, so can decode back to a string. - return username.decode("ascii") - - -@attr.s(frozen=True, slots=True, order=False) -class RoomStreamToken: - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] ▼ [1] ▼ [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream, events are ordered by - `stream_ordering` (when they arrived at the homeserver). - - When traversing historic events, events are first ordered by their `depth` - (`topological_ordering` in the event graph) and tie-broken by - `stream_ordering` (when the event arrived at the homeserver). - - If you're looking for more info about what a token with all of the - underscores means, ex. - `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring - for `StreamToken` below. - - --- - - Live tokens start with an "s" followed by the `stream_ordering` of the event - that comes before the position of the token. Said another way: - `stream_ordering` uniquely identifies a persisted event. The live token - means "the position just after the event identified by `stream_ordering`". - An example token is: - - s2633508 - - --- - - Historic tokens start with a "t" followed by the `depth` - (`topological_ordering` in the event graph) of the event that comes before - the position of the token, followed by "-", followed by the - `stream_ordering` of the event that comes before the position of the token. - An example token is: - - t426-2633508 - - --- - - There is also a third mode for live tokens where the token starts with "m", - which is sometimes used when using sharded event persisters. In this case - the events stream is considered to be a set of streams (one for each writer) - and the token encodes the vector clock of positions of each writer in their - respective streams. - - The format of the token in such case is an initial integer min position, - followed by the mapping of instance ID to position separated by '.' and '~': - - m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... - - The `min_pos` corresponds to the minimum position all writers have persisted - up to, and then only writers that are ahead of that position need to be - encoded. An example token is: - - m56~2.58~3.59 - - Which corresponds to a set of three (or more writers) where instances 2 and - 3 (these are instance IDs that can be looked up in the DB to fetch the more - commonly used instance names) are at positions 58 and 59 respectively, and - all other instances are at position 56. - - Note: The `RoomStreamToken` cannot have both a topological part and an - instance map. - - --- - - For caching purposes, `RoomStreamToken`s and by extension, all their - attributes, must be hashable. - """ - - topological: Optional[int] = attr.ib( - validator=attr.validators.optional(attr.validators.instance_of(int)), - ) - stream: int = attr.ib(validator=attr.validators.instance_of(int)) - - instance_map: "frozendict[str, int]" = attr.ib( - factory=frozendict, - validator=attr.validators.deep_mapping( - key_validator=attr.validators.instance_of(str), - value_validator=attr.validators.instance_of(int), - mapping_validator=attr.validators.instance_of(frozendict), - ), - ) - - def __attrs_post_init__(self) -> None: - """Validates that both `topological` and `instance_map` aren't set.""" - - if self.instance_map and self.topological: - raise ValueError( - "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." - ) - - @classmethod - async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - if string[0] == "t": - parts = string[1:].split("-", 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - if string[0] == "m": - parts = string[1:].split("~") - stream = int(parts[0]) - - instance_map = {} - for part in parts[1:]: - key, value = part.split(".") - instance_id = int(key) - pos = int(value) - - instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] - instance_map[instance_name] = pos - - return cls( - topological=None, - stream=stream, - instance_map=frozendict(instance_map), - ) - except CancelledError: - raise - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - @classmethod - def parse_stream_token(cls, string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": - """Return a new token such that if an event is after both this token and - the other token, then its after the returned token too. - """ - - if self.topological or other.topological: - raise Exception("Can't advance topological tokens") - - max_stream = max(self.stream, other.stream) - - instance_map = { - instance: max( - self.instance_map.get(instance, self.stream), - other.instance_map.get(instance, other.stream), - ) - for instance in set(self.instance_map).union(other.instance_map) - } - - return RoomStreamToken(None, max_stream, frozendict(instance_map)) - - def as_historical_tuple(self) -> Tuple[int, int]: - """Returns a tuple of `(topological, stream)` for historical tokens. - - Raises if not an historical token (i.e. doesn't have a topological part). - """ - if self.topological is None: - raise Exception( - "Cannot call `RoomStreamToken.as_historical_tuple` on live token" - ) - - return self.topological, self.stream - - def get_stream_pos_for_instance(self, instance_name: str) -> int: - """Get the stream position that the given writer was at at this token. - - This only makes sense for "live" tokens that may have a vector clock - component, and so asserts that this is a "live" token. - """ - assert self.topological is None - - # If we don't have an entry for the instance we can assume that it was - # at `self.stream`. - return self.instance_map.get(instance_name, self.stream) - - def get_max_stream_pos(self) -> int: - """Get the maximum stream position referenced in this token. - - The corresponding "min" position is, by definition just `self.stream`. - - This is used to handle tokens that have non-empty `instance_map`, and so - reference stream positions after the `self.stream` position. - """ - return max(self.instance_map.values(), default=self.stream) - - async def to_string(self, store: "DataStore") -> str: - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - elif self.instance_map: - entries = [] - for name, pos in self.instance_map.items(): - instance_id = await store.get_id_for_instance(name) - entries.append(f"{instance_id}.{pos}") - - encoded_map = "~".join(entries) - return f"m{self.stream}~{encoded_map}" - else: - return "s%d" % (self.stream,) - - -class StreamKeyType: - """Known stream types. - - A stream is a list of entities ordered by an incrementing "stream token". - """ - - ROOM: Final = "room_key" - PRESENCE: Final = "presence_key" - TYPING: Final = "typing_key" - RECEIPT: Final = "receipt_key" - ACCOUNT_DATA: Final = "account_data_key" - PUSH_RULES: Final = "push_rules_key" - TO_DEVICE: Final = "to_device_key" - DEVICE_LIST: Final = "device_list_key" - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StreamToken: - """A collection of keys joined together by underscores in the following - order and which represent the position in their respective streams. - - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` - 1. `room_key`: `s2633508` which is a `RoomStreamToken` - - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - - See the docstring for `RoomStreamToken` for more details. - 2. `presence_key`: `17` - 3. `typing_key`: `338` - 4. `receipt_key`: `6732159` - 5. `account_data_key`: `1082514` - 6. `push_rules_key`: `541479` - 7. `to_device_key`: `274711` - 8. `device_list_key`: `265584` - 9. `groups_key`: `1` (note that this key is now unused) - - You can see how many of these keys correspond to the various - fields in a "/sync" response: - ```json - { - "next_batch": "s12_4_0_1_1_1_1_4_1", - "presence": { - "events": [] - }, - "device_lists": { - "changed": [] - }, - "rooms": { - "join": { - "!QrZlfIDQLNLdZHqTnt:hs1": { - "timeline": { - "events": [], - "prev_batch": "s10_4_0_1_1_1_1_4_1", - "limited": false - }, - "state": { - "events": [] - }, - "account_data": { - "events": [] - }, - "ephemeral": { - "events": [] - } - } - } - } - } - ``` - - --- - - For caching purposes, `StreamToken`s and by extension, all their attributes, - must be hashable. - """ - - room_key: RoomStreamToken = attr.ib( - validator=attr.validators.instance_of(RoomStreamToken) - ) - presence_key: int - typing_key: int - receipt_key: int - account_data_key: int - push_rules_key: int - to_device_key: int - device_list_key: int - # Note that the groups key is no longer used and may have bogus values. - groups_key: int - - _SEPARATOR = "_" - START: ClassVar["StreamToken"] - - @classmethod - @cancellable - async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": - """ - Creates a RoomStreamToken from its textual representation. - """ - try: - keys = string.split(cls._SEPARATOR) - while len(keys) < len(attr.fields(cls)): - # i.e. old token from before receipt_key - keys.append("0") - return cls( - await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) - ) - except CancelledError: - raise - except Exception: - raise SynapseError(400, "Invalid stream token") - - async def to_string(self, store: "DataStore") -> str: - return self._SEPARATOR.join( - [ - await self.room_key.to_string(store), - str(self.presence_key), - str(self.typing_key), - str(self.receipt_key), - str(self.account_data_key), - str(self.push_rules_key), - str(self.to_device_key), - str(self.device_list_key), - # Note that the groups key is no longer used, but it is still - # serialized so that there will not be confusion in the future - # if additional tokens are added. - str(self.groups_key), - ] - ) - - @property - def room_stream_id(self) -> int: - return self.room_key.stream - - def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": - """Advance the given key in the token to a new value if and only if the - new value is after the old value. - - :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. - """ - if key == StreamKeyType.ROOM: - new_token = self.copy_and_replace( - StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) - ) - return new_token - - new_token = self.copy_and_replace(key, new_value) - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) - - if old_id < new_id: - return new_token - else: - return self - - def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": - return attr.evolve(self, **{key: new_value}) - - -StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class PersistedEventPosition: - """Position of a newly persisted event with instance that persisted it. - - This can be used to test whether the event is persisted before or after a - RoomStreamToken. - """ - - instance_name: str - stream: int - - def persisted_after(self, token: RoomStreamToken) -> bool: - return token.get_stream_pos_for_instance(self.instance_name) < self.stream - - def to_room_stream_token(self) -> RoomStreamToken: - """Converts the position to a room stream token such that events - persisted in the same room after this position will be after the - returned `RoomStreamToken`. - - Note: no guarantees are made about ordering w.r.t. events in other - rooms. - """ - # Doing the naive thing satisfies the desired properties described in - # the docstring. - return RoomStreamToken(None, self.stream) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThirdPartyInstanceID: - appservice_id: Optional[str] - network_id: Optional[str] - - # Deny iteration because it will bite you if you try to create a singleton - # set by: - # users = set(user) - def __iter__(self) -> NoReturn: - raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - - # Because this class is a frozen class, it is deeply immutable. - def __copy__(self) -> "ThirdPartyInstanceID": - return self - - def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": - return self - - @classmethod - def from_string(cls, s: str) -> "ThirdPartyInstanceID": - bits = s.split("|", 2) - if len(bits) != 2: - raise SynapseError(400, "Invalid ID %r" % (s,)) - - return cls(appservice_id=bits[0], network_id=bits[1]) - - def to_string(self) -> str: - return "%s|%s" % (self.appservice_id, self.network_id) - - __str__ = to_string - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ReadReceipt: - """Information about a read-receipt""" - - room_id: str - receipt_type: str - user_id: str - event_ids: List[str] - thread_id: Optional[str] - data: JsonDict - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceListUpdates: - """ - An object containing a diff of information regarding other users' device lists, intended for - a recipient to carry out device list tracking. - - Attributes: - changed: A set of users whose device lists have changed recently. - left: A set of users who the recipient no longer needs to track the device lists of. - Typically when those users no longer share any end-to-end encryption enabled rooms. - """ - - # We need to use a factory here, otherwise `set` is not evaluated at - # object instantiation, but instead at class definition instantiation. - # The latter happening only once, thus always giving you the same sets - # across multiple DeviceListUpdates instances. - # Also see: don't define mutable default arguments. - changed: Set[str] = attr.ib(factory=set) - left: Set[str] = attr.ib(factory=set) - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - -def get_verify_key_from_cross_signing_key( - key_info: Mapping[str, Any] -) -> Tuple[str, VerifyKey]: - """Get the key ID and signedjson verify key from a cross-signing key dict - - Args: - key_info: a cross-signing key dict, which must have a "keys" - property that has exactly one item in it - - Returns: - the key ID and verify key for the cross-signing key - """ - # make sure that a `keys` field is provided - if "keys" not in key_info: - raise ValueError("Invalid key") - keys = key_info["keys"] - # and that it contains exactly one key - if len(keys) == 1: - key_id, key_data = next(iter(keys.items())) - return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) - else: - raise ValueError("Invalid key") - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class UserInfo: - """Holds information about a user. Result of get_userinfo_by_id. - - Attributes: - user_id: ID of the user. - appservice_id: Application service ID that created this user. - consent_server_notice_sent: Version of policy documents the user has been sent. - consent_version: Version of policy documents the user has consented to. - creation_ts: Creation timestamp of the user. - is_admin: True if the user is an admin. - is_deactivated: True if the user has been deactivated. - is_guest: True if the user is a guest user. - is_shadow_banned: True if the user has been shadow-banned. - user_type: User type (None for normal user, 'support' and 'bot' other options). - """ - - user_id: UserID - appservice_id: Optional[int] - consent_server_notice_sent: Optional[str] - consent_version: Optional[str] - user_type: Optional[str] - creation_ts: int - is_admin: bool - is_deactivated: bool - is_guest: bool - is_shadow_banned: bool - - -class UserProfile(TypedDict): - user_id: str - display_name: Optional[str] - avatar_url: Optional[str] - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class RetentionPolicy: - min_lifetime: Optional[int] = None - max_lifetime: Optional[int] = None diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py new file mode 100644 index 0000000000..f2d436ddc3 --- /dev/null +++ b/synapse/types/__init__.py @@ -0,0 +1,928 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import re +import string +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Mapping, + Match, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +import attr +from frozendict import frozendict +from signedjson.key import decode_verify_key_bytes +from signedjson.types import VerifyKey +from typing_extensions import Final, TypedDict +from unpaddedbase64 import decode_base64 +from zope.interface import Interface + +from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import ( + IReactorCore, + IReactorPluggableNameResolver, + IReactorSSL, + IReactorTCP, + IReactorThreads, + IReactorTime, +) + +from synapse.api.errors import Codes, SynapseError +from synapse.util.cancellation import cancellable +from synapse.util.stringutils import parse_and_validate_server_name + +if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService + from synapse.storage.databases.main import DataStore, PurgeEventsStore + from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore + +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] + +# JSON types. These could be made stronger, but will do for now. +# A JSON-serialisable dict. +JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] +# A JSON-serialisable object. +JsonSerializable = object + + +# Note that this seems to require inheriting *directly* from Interface in order +# for mypy-zope to realize it is an interface. +class ISynapseReactor( + IReactorTCP, + IReactorSSL, + IReactorPluggableNameResolver, + IReactorTime, + IReactorCore, + IReactorThreads, + Interface, +): + """The interfaces necessary for Synapse to function.""" + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class Requester: + """ + Represents the user making a request + + Attributes: + user: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request has been shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + """ + + user: "UserID" + access_token_id: Optional[int] + is_guest: bool + shadow_banned: bool + device_id: Optional[str] + app_service: Optional["ApplicationService"] + authenticated_entity: str + + def serialize(self) -> Dict[str, Any]: + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Returns: + dict + """ + return { + "user_id": self.user.to_string(), + "access_token_id": self.access_token_id, + "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, + "device_id": self.device_id, + "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, + } + + @staticmethod + def deserialize( + store: "ApplicationServiceWorkerStore", input: Dict[str, Any] + ) -> "Requester": + """Converts a dict that was produced by `serialize` back into a + Requester. + + Args: + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` + + Returns: + Requester + """ + appservice = None + if input["app_server_id"]: + appservice = store.get_app_service_by_id(input["app_server_id"]) + + return Requester( + user=UserID.from_string(input["user_id"]), + access_token_id=input["access_token_id"], + is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], + device_id=input["device_id"], + app_service=appservice, + authenticated_entity=input["authenticated_entity"], + ) + + +def create_requester( + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: bool = False, + shadow_banned: bool = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, +) -> Requester: + """ + Create a new ``Requester`` object + + Args: + user_id: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request is shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + + return Requester( + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, + ) + + +def get_domain_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[idx + 1 :] + + +def get_localpart_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[1:idx] + + +DS = TypeVar("DS", bound="DomainSpecificString") + + +@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) +class DomainSpecificString(metaclass=abc.ABCMeta): + """Common base class among ID/name strings that have a local part and a + domain name, prefixed with a sigil. + + Has the fields: + + 'localpart' : The local part of the name (without the leading sigil) + 'domain' : The domain part of the name + """ + + SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore + + localpart: str + domain: str + + # Because this is a frozen class, it is deeply immutable. + def __copy__(self: DS) -> DS: + return self + + def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: + return self + + @classmethod + def from_string(cls: Type[DS], s: str) -> DS: + """Parse the string given by 's' into a structure object.""" + if len(s) < 1 or s[0:1] != cls.SIGIL: + raise SynapseError( + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + parts = s[1:].split(":", 1) + if len(parts) != 2: + raise SynapseError( + 400, + "Expected %s of the form '%slocalname:domain'" + % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + domain = parts[1] + # This code will need changing if we want to support multiple domain + # names on one HS + return cls(localpart=parts[0], domain=domain) + + def to_string(self) -> str: + """Return a string encoding the fields of the structure object.""" + return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + + @classmethod + def is_valid(cls: Type[DS], s: str) -> bool: + """Parses the input string and attempts to ensure it is valid.""" + # TODO: this does not reject an empty localpart or an overly-long string. + # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar + try: + obj = cls.from_string(s) + # Apply additional validation to the domain. This is only done + # during is_valid (and not part of from_string) since it is + # possible for invalid data to exist in room-state, etc. + parse_and_validate_server_name(obj.domain) + return True + except Exception: + return False + + __repr__ = to_string + + +@attr.s(slots=True, frozen=True, repr=False) +class UserID(DomainSpecificString): + """Structure representing a user ID.""" + + SIGIL = "@" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomAlias(DomainSpecificString): + """Structure representing a room name.""" + + SIGIL = "#" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomID(DomainSpecificString): + """Structure representing a room id.""" + + SIGIL = "!" + + +@attr.s(slots=True, frozen=True, repr=False) +class EventID(DomainSpecificString): + """Structure representing an event id.""" + + SIGIL = "$" + + +mxid_localpart_allowed_characters = set( + "_-./=" + string.ascii_lowercase + string.digits +) + + +def contains_invalid_mxid_characters(localpart: str) -> bool: + """Check for characters not allowed in an mxid or groupid localpart + + Args: + localpart: the localpart to be checked + + Returns: + True if there are any naughty characters + """ + return any(c not in mxid_localpart_allowed_characters for c in localpart) + + +UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") + +# the following is a pattern which matches '=', and bytes which are not allowed in a mxid +# localpart. +# +# It works by: +# * building a string containing the allowed characters (excluding '=') +# * escaping every special character with a backslash (to stop '-' being interpreted as a +# range operator) +# * wrapping it in a '[^...]' regex +# * converting the whole lot to a 'bytes' sequence, so that we can use it to match +# bytes rather than strings +# +NON_MXID_CHARACTER_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( + "ascii" + ) +) + + +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: + """Map a username onto a string suitable for a MXID + + This follows the algorithm laid out at + https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. + + Args: + username: username to be mapped + case_sensitive: true if TEST and test should be mapped + onto different mxids + + Returns: + string suitable for a mxid localpart + """ + if not isinstance(username, bytes): + username = username.encode("utf-8") + + # first we sort out upper-case characters + if case_sensitive: + + def f1(m: Match[bytes]) -> bytes: + return b"_" + m.group().lower() + + username = UPPER_CASE_PATTERN.sub(f1, username) + else: + username = username.lower() + + # then we sort out non-ascii characters by converting to the hex equivalent. + def f2(m: Match[bytes]) -> bytes: + return b"=%02x" % (m.group()[0],) + + username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) + + # we also do the =-escaping to mxids starting with an underscore. + username = re.sub(b"^_", b"=5f", username) + + # we should now only have ascii bytes left, so can decode back to a string. + return username.decode("ascii") + + +@attr.s(frozen=True, slots=True, order=False) +class RoomStreamToken: + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] ▼ [1] ▼ [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream, events are ordered by + `stream_ordering` (when they arrived at the homeserver). + + When traversing historic events, events are first ordered by their `depth` + (`topological_ordering` in the event graph) and tie-broken by + `stream_ordering` (when the event arrived at the homeserver). + + If you're looking for more info about what a token with all of the + underscores means, ex. + `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring + for `StreamToken` below. + + --- + + Live tokens start with an "s" followed by the `stream_ordering` of the event + that comes before the position of the token. Said another way: + `stream_ordering` uniquely identifies a persisted event. The live token + means "the position just after the event identified by `stream_ordering`". + An example token is: + + s2633508 + + --- + + Historic tokens start with a "t" followed by the `depth` + (`topological_ordering` in the event graph) of the event that comes before + the position of the token, followed by "-", followed by the + `stream_ordering` of the event that comes before the position of the token. + An example token is: + + t426-2633508 + + --- + + There is also a third mode for live tokens where the token starts with "m", + which is sometimes used when using sharded event persisters. In this case + the events stream is considered to be a set of streams (one for each writer) + and the token encodes the vector clock of positions of each writer in their + respective streams. + + The format of the token in such case is an initial integer min position, + followed by the mapping of instance ID to position separated by '.' and '~': + + m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... + + The `min_pos` corresponds to the minimum position all writers have persisted + up to, and then only writers that are ahead of that position need to be + encoded. An example token is: + + m56~2.58~3.59 + + Which corresponds to a set of three (or more writers) where instances 2 and + 3 (these are instance IDs that can be looked up in the DB to fetch the more + commonly used instance names) are at positions 58 and 59 respectively, and + all other instances are at position 56. + + Note: The `RoomStreamToken` cannot have both a topological part and an + instance map. + + --- + + For caching purposes, `RoomStreamToken`s and by extension, all their + attributes, must be hashable. + """ + + topological: Optional[int] = attr.ib( + validator=attr.validators.optional(attr.validators.instance_of(int)), + ) + stream: int = attr.ib(validator=attr.validators.instance_of(int)) + + instance_map: "frozendict[str, int]" = attr.ib( + factory=frozendict, + validator=attr.validators.deep_mapping( + key_validator=attr.validators.instance_of(str), + value_validator=attr.validators.instance_of(int), + mapping_validator=attr.validators.instance_of(frozendict), + ), + ) + + def __attrs_post_init__(self) -> None: + """Validates that both `topological` and `instance_map` aren't set.""" + + if self.instance_map and self.topological: + raise ValueError( + "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." + ) + + @classmethod + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + if string[0] == "t": + parts = string[1:].split("-", 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] + instance_map[instance_name] = pos + + return cls( + topological=None, + stream=stream, + instance_map=frozendict(instance_map), + ) + except CancelledError: + raise + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return RoomStreamToken(None, max_stream, frozendict(instance_map)) + + def as_historical_tuple(self) -> Tuple[int, int]: + """Returns a tuple of `(topological, stream)` for historical tokens. + + Raises if not an historical token (i.e. doesn't have a topological part). + """ + if self.topological is None: + raise Exception( + "Cannot call `RoomStreamToken.as_historical_tuple` on live token" + ) + + return self.topological, self.stream + + def get_stream_pos_for_instance(self, instance_name: str) -> int: + """Get the stream position that the given writer was at at this token. + + This only makes sense for "live" tokens that may have a vector clock + component, and so asserts that this is a "live" token. + """ + assert self.topological is None + + # If we don't have an entry for the instance we can assume that it was + # at `self.stream`. + return self.instance_map.get(instance_name, self.stream) + + def get_max_stream_pos(self) -> int: + """Get the maximum stream position referenced in this token. + + The corresponding "min" position is, by definition just `self.stream`. + + This is used to handle tokens that have non-empty `instance_map`, and so + reference stream positions after the `self.stream` position. + """ + return max(self.instance_map.values(), default=self.stream) + + async def to_string(self, store: "DataStore") -> str: + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + elif self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + instance_id = await store.get_id_for_instance(name) + entries.append(f"{instance_id}.{pos}") + + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + else: + return "s%d" % (self.stream,) + + +class StreamKeyType: + """Known stream types. + + A stream is a list of entities ordered by an incrementing "stream token". + """ + + ROOM: Final = "room_key" + PRESENCE: Final = "presence_key" + TYPING: Final = "typing_key" + RECEIPT: Final = "receipt_key" + ACCOUNT_DATA: Final = "account_data_key" + PUSH_RULES: Final = "push_rules_key" + TO_DEVICE: Final = "to_device_key" + DEVICE_LIST: Final = "device_list_key" + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StreamToken: + """A collection of keys joined together by underscores in the following + order and which represent the position in their respective streams. + + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` + 1. `room_key`: `s2633508` which is a `RoomStreamToken` + - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` + - See the docstring for `RoomStreamToken` for more details. + 2. `presence_key`: `17` + 3. `typing_key`: `338` + 4. `receipt_key`: `6732159` + 5. `account_data_key`: `1082514` + 6. `push_rules_key`: `541479` + 7. `to_device_key`: `274711` + 8. `device_list_key`: `265584` + 9. `groups_key`: `1` (note that this key is now unused) + + You can see how many of these keys correspond to the various + fields in a "/sync" response: + ```json + { + "next_batch": "s12_4_0_1_1_1_1_4_1", + "presence": { + "events": [] + }, + "device_lists": { + "changed": [] + }, + "rooms": { + "join": { + "!QrZlfIDQLNLdZHqTnt:hs1": { + "timeline": { + "events": [], + "prev_batch": "s10_4_0_1_1_1_1_4_1", + "limited": false + }, + "state": { + "events": [] + }, + "account_data": { + "events": [] + }, + "ephemeral": { + "events": [] + } + } + } + } + } + ``` + + --- + + For caching purposes, `StreamToken`s and by extension, all their attributes, + must be hashable. + """ + + room_key: RoomStreamToken = attr.ib( + validator=attr.validators.instance_of(RoomStreamToken) + ) + presence_key: int + typing_key: int + receipt_key: int + account_data_key: int + push_rules_key: int + to_device_key: int + device_list_key: int + # Note that the groups key is no longer used and may have bogus values. + groups_key: int + + _SEPARATOR = "_" + START: ClassVar["StreamToken"] + + @classmethod + @cancellable + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + """ + Creates a RoomStreamToken from its textual representation. + """ + try: + keys = string.split(cls._SEPARATOR) + while len(keys) < len(attr.fields(cls)): + # i.e. old token from before receipt_key + keys.append("0") + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) + except CancelledError: + raise + except Exception: + raise SynapseError(400, "Invalid stream token") + + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + # Note that the groups key is no longer used, but it is still + # serialized so that there will not be confusion in the future + # if additional tokens are added. + str(self.groups_key), + ] + ) + + @property + def room_stream_id(self) -> int: + return self.room_key.stream + + def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": + """Advance the given key in the token to a new value if and only if the + new value is after the old value. + + :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. + """ + if key == StreamKeyType.ROOM: + new_token = self.copy_and_replace( + StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + + if old_id < new_id: + return new_token + else: + return self + + def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": + return attr.evolve(self, **{key: new_value}) + + +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name: str + stream: int + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.get_stream_pos_for_instance(self.instance_name) < self.stream + + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarantees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + + # Deny iteration because it will bite you if you try to create a singleton + # set by: + # users = set(user) + def __iter__(self) -> NoReturn: + raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) + + # Because this class is a frozen class, it is deeply immutable. + def __copy__(self) -> "ThirdPartyInstanceID": + return self + + def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": + return self + + @classmethod + def from_string(cls, s: str) -> "ThirdPartyInstanceID": + bits = s.split("|", 2) + if len(bits) != 2: + raise SynapseError(400, "Invalid ID %r" % (s,)) + + return cls(appservice_id=bits[0], network_id=bits[1]) + + def to_string(self) -> str: + return "%s|%s" % (self.appservice_id, self.network_id) + + __str__ = to_string + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ReadReceipt: + """Information about a read-receipt""" + + room_id: str + receipt_type: str + user_id: str + event_ids: List[str] + thread_id: Optional[str] + data: JsonDict + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users who the recipient no longer needs to track the device lists of. + Typically when those users no longer share any end-to-end encryption enabled rooms. + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceListUpdates instances. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + +def get_verify_key_from_cross_signing_key( + key_info: Mapping[str, Any] +) -> Tuple[str, VerifyKey]: + """Get the key ID and signedjson verify key from a cross-signing key dict + + Args: + key_info: a cross-signing key dict, which must have a "keys" + property that has exactly one item in it + + Returns: + the key ID and verify key for the cross-signing key + """ + # make sure that a `keys` field is provided + if "keys" not in key_info: + raise ValueError("Invalid key") + keys = key_info["keys"] + # and that it contains exactly one key + if len(keys) == 1: + key_id, key_data = next(iter(keys.items())) + return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) + else: + raise ValueError("Invalid key") + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RetentionPolicy: + min_lifetime: Optional[int] = None + max_lifetime: Optional[int] = None diff --git a/synapse/types/state.py b/synapse/types/state.py new file mode 100644 index 0000000000..0004d955b4 --- /dev/null +++ b/synapse/types/state.py @@ -0,0 +1,567 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import ( + TYPE_CHECKING, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + TypeVar, +) + +import attr +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types import MutableStateMap, StateKey, StateMap + +if TYPE_CHECKING: + from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad + + +logger = logging.getLogger(__name__) + +# Used for generic functions below +T = TypeVar("T") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StateFilter: + """A filter used when querying for state. + + Attributes: + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not + appear in `types`. + """ + + types: "frozendict[str, Optional[FrozenSet[str]]]" + include_others: bool = False + + def __attrs_post_init__(self) -> None: + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + # this is needed to work around the fact that StateFilter is frozen + object.__setattr__( + self, + "types", + frozendict({k: v for k, v in self.types.items() if v is not None}), + ) + + @staticmethod + def all() -> "StateFilter": + """Returns a filter that fetches everything. + + Returns: + The state filter. + """ + return _ALL_STATE_FILTER + + @staticmethod + def none() -> "StateFilter": + """Returns a filter that fetches nothing. + + Returns: + The new state filter. + """ + return _NONE_STATE_FILTER + + @staticmethod + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": + """Creates a filter that only fetches the given types + + Args: + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type + + Returns: + The new state filter. + """ + type_dict: Dict[str, Optional[Set[str]]] = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) # type: ignore + + return StateFilter( + types=frozendict( + (k, frozenset(v) if v is not None else None) + for k, v in type_dict.items() + ) + ) + + @staticmethod + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members: Set of user IDs + + Returns: + The new state filter + """ + return StateFilter( + types=frozendict({EventTypes.Member: frozenset(members)}), + include_others=True, + ) + + @staticmethod + def freeze( + types: Mapping[str, Optional[Collection[str]]], include_others: bool + ) -> "StateFilter": + """ + Returns a (frozen) StateFilter with the same contents as the parameters + specified here, which can be made of mutable types. + """ + types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} + for state_types, state_keys in types.items(): + if state_keys is not None: + types_with_frozen_values[state_types] = frozenset(state_keys) + else: + types_with_frozen_values[state_types] = None + + return StateFilter( + frozendict(types_with_frozen_values), include_others=include_others + ) + + def return_expanded(self) -> "StateFilter": + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + The new state filter. + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in self.types.items() + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + elif EventTypes.Member in self.types: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), + include_others=True, + ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER + + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). + """ + + where_clause = "" + where_args: List[str] = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in self.types.items(): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self) -> Optional[int]: + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: + """Returns the state filtered with by this StateFilter. + + Args: + state: The state map to filter + + Returns: + The filtered state map. + This is a copy, so it's safe to mutate. + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in state_dict.items(): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self) -> bool: + """Whether this filter fetches everything or not + + Returns: + True if the filter fetches everything. + """ + return self.include_others and not self.types + + def has_wildcards(self) -> bool: + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + True if the filter includes wildcards. + """ + + return self.include_others or any( + state_keys is None for state_keys in self.types.values() + ) + + def concrete_types(self) -> List[Tuple[str, str]]: + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + A list of type/state_keys tuples. + """ + return [ + (t, s) + for t, state_keys in self.types.items() + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types=frozendict( + {k: v for k, v in self.types.items() if k != EventTypes.Member} + ), + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + def _decompose_into_four_parts( + self, + ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: + """ + Decomposes this state filter into 4 constituent parts, which can be + thought of as this: + all? - minus_wildcards + plus_wildcards + plus_state_keys + + where + * all represents ALL state + * minus_wildcards represents entire state types to remove + * plus_wildcards represents entire state types to add + * plus_state_keys represents individual state keys to add + + See `recompose_from_four_parts` for the other direction of this + correspondence. + """ + is_all = self.include_others + excluded_types: Set[str] = {t for t in self.types if is_all} + wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} + concrete_keys: Set[StateKey] = set(self.concrete_types()) + + return (is_all, excluded_types), (wildcard_types, concrete_keys) + + @staticmethod + def _recompose_from_four_parts( + all_part: bool, + minus_wildcards: Set[str], + plus_wildcards: Set[str], + plus_state_keys: Set[StateKey], + ) -> "StateFilter": + """ + Recomposes a state filter from 4 parts. + + See `decompose_into_four_parts` (the other direction of this + correspondence) for descriptions on each of the parts. + """ + + # {state type -> set of state keys OR None for wildcard} + # (The same structure as that of a StateFilter.) + new_types: Dict[str, Optional[Set[str]]] = {} + + # if we start with all, insert the excluded statetypes as empty sets + # to prevent them from being included + if all_part: + new_types.update({state_type: set() for state_type in minus_wildcards}) + + # insert the plus wildcards + new_types.update({state_type: None for state_type in plus_wildcards}) + + # insert the specific state keys + for state_type, state_key in plus_state_keys: + if state_type in new_types: + entry = new_types[state_type] + if entry is not None: + entry.add(state_key) + elif not all_part: + # don't insert if the entire type is already included by + # include_others as this would actually shrink the state allowed + # by this filter. + new_types[state_type] = {state_key} + + return StateFilter.freeze(new_types, include_others=all_part) + + def approx_difference(self, other: "StateFilter") -> "StateFilter": + """ + Returns a state filter which represents `self - other`. + + This is useful for determining what state remains to be pulled out of the + database if we want the state included by `self` but already have the state + included by `other`. + + The returned state filter + - MUST include all state events that are included by this filter (`self`) + unless they are included by `other`; + - MUST NOT include state events not included by this filter (`self`); and + - MAY be an over-approximation: the returned state filter + MAY additionally include some state events from `other`. + + This implementation attempts to return the narrowest such state filter. + In the case that `self` contains wildcards for state types where + `other` contains specific state keys, an approximation must be made: + the returned state filter keeps the wildcard, as state filters are not + able to express 'all state keys except some given examples'. + e.g. + StateFilter(m.room.member -> None (wildcard)) + minus + StateFilter(m.room.member -> {'@wombat:example.org'}) + is approximated as + StateFilter(m.room.member -> None (wildcard)) + """ + + # We first transform self and other into an alternative representation: + # - whether or not they include all events to begin with ('all') + # - if so, which event types are excluded? ('excludes') + # - which entire event types to include ('wildcards') + # - which concrete state keys to include ('concrete state keys') + (self_all, self_excludes), ( + self_wildcards, + self_concrete_keys, + ) = self._decompose_into_four_parts() + (other_all, other_excludes), ( + other_wildcards, + other_concrete_keys, + ) = other._decompose_into_four_parts() + + # Start with an estimate of the difference based on self + new_all = self_all + # Wildcards from the other can be added to the exclusion filter + new_excludes = self_excludes | other_wildcards + # We remove wildcards that appeared as wildcards in the other + new_wildcards = self_wildcards - other_wildcards + # We filter out the concrete state keys that appear in the other + # as wildcards or concrete state keys. + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in self_concrete_keys + if state_type not in other_wildcards + } - other_concrete_keys + + if other_all: + if self_all: + # If self starts with all, then we add as wildcards any + # types which appear in the other's exclusion filter (but + # aren't in the self exclusion filter). This is as the other + # filter will return everything BUT the types in its exclusion, so + # we need to add those excluded types that also match the self + # filter as wildcard types in the new filter. + new_wildcards |= other_excludes.difference(self_excludes) + + # If other is an `include_others` then the difference isn't. + new_all = False + # (We have no need for excludes when we don't start with all, as there + # is nothing to exclude.) + new_excludes = set() + + # We also filter out all state types that aren't in the exclusion + # list of the other. + new_wildcards &= other_excludes + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in new_concrete_keys + if state_type in other_excludes + } + + # Transform our newly-constructed state filter from the alternative + # representation back into the normal StateFilter representation. + return StateFilter._recompose_from_four_parts( + new_all, new_excludes, new_wildcards, new_concrete_keys + ) + + def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: + """Check if we need to wait for full state to complete to calculate this state + + If we have a state filter which is completely satisfied even with partial + state, then we don't need to await_full_state before we can return it. + + Args: + is_mine_id: a callable which confirms if a given state_key matches a mxid + of a local user + """ + # if we haven't requested membership events, then it depends on the value of + # 'include_others' + if EventTypes.Member not in self.types: + return self.include_others + + # if we're looking for *all* membership events, then we have to wait + member_state_keys = self.types[EventTypes.Member] + if member_state_keys is None: + return True + + # otherwise, consider whose membership we are looking for. If it's entirely + # local users, then we don't need to wait. + for state_key in member_state_keys: + if not is_mine_id(state_key): + # remote user + return True + + # local users only + return False + + +_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) +_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/visibility.py b/synapse/visibility.py index b443857571..e442de3173 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -26,8 +26,8 @@ from synapse.events.utils import prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import Clock logger = logging.getLogger(__name__) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d4e6d4236c..a433e70870 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -22,8 +22,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.server import HomeServer -from synapse.storage.state import StateFilter from synapse.types import JsonDict, RoomID, StateMap, UserID +from synapse.types.state import StateFilter from synapse.util import Clock from tests.unittest import HomeserverTestCase, TestCase -- cgit 1.5.1 From e2a1adbf5d11288f2134ced1f84c6ffdd91a9357 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 13 Dec 2022 00:54:46 +0000 Subject: Allow selecting "prejoin" events by state keys (#14642) * Declare new config * Parse new config * Read new config * Don't use trial/our TestCase where it's not needed Before: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m2.277s user 0m2.186s sys 0m0.083s ``` After: ``` $ time trial tests/events/test_utils.py > /dev/null real 0m0.566s user 0m0.508s sys 0m0.056s ``` * Helper to upsert to event fields without exceeding size limits. * Use helper when adding invite/knock state Now that we allow admins to include events in prejoin room state with arbitrary state keys, be a good Matrix citizen and ensure they don't accidentally create an oversized event. * Changelog * Move StateFilter tests should have done this in #14668 * Add extra methods to StateFilter * Use StateFilter * Ensure test file enforces typed defs; alphabetise * Workaround surprising get_current_state_ids * Whoops, fix mypy --- changelog.d/14642.feature | 1 + docs/usage/configuration/config_documentation.md | 57 ++- mypy.ini | 12 +- synapse/config/_util.py | 3 + synapse/config/api.py | 63 ++- synapse/events/utils.py | 32 +- synapse/handlers/message.py | 29 +- synapse/storage/databases/main/events_worker.py | 33 +- synapse/types/state.py | 18 + tests/config/test_api.py | 145 ++++++ tests/events/test_utils.py | 35 +- tests/storage/test_state.py | 623 +--------------------- tests/types/__init__.py | 0 tests/types/test_state.py | 627 +++++++++++++++++++++++ 14 files changed, 983 insertions(+), 695 deletions(-) create mode 100644 changelog.d/14642.feature create mode 100644 tests/config/test_api.py create mode 100644 tests/types/__init__.py create mode 100644 tests/types/test_state.py (limited to 'synapse/handlers/message.py') diff --git a/changelog.d/14642.feature b/changelog.d/14642.feature new file mode 100644 index 0000000000..cbc9db10c3 --- /dev/null +++ b/changelog.d/14642.feature @@ -0,0 +1 @@ +Allow selecting "prejoin" events by state keys in addition to event types. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index dc5e5ac597..4d32902fea 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -2501,32 +2501,53 @@ Config settings related to the client/server API --- ### `room_prejoin_state` -Controls for the state that is shared with users who receive an invite -to a room. By default, the following state event types are shared with users who -receive invites to the room: -- m.room.join_rules -- m.room.canonical_alias -- m.room.avatar -- m.room.encryption -- m.room.name -- m.room.create -- m.room.topic +This setting controls the state that is shared with users upon receiving an +invite to a room, or in reply to a knock on a room. By default, the following +state events are shared with users: + +- `m.room.join_rules` +- `m.room.canonical_alias` +- `m.room.avatar` +- `m.room.encryption` +- `m.room.name` +- `m.room.create` +- `m.room.topic` To change the default behavior, use the following sub-options: -* `disable_default_event_types`: set to true to disable the above defaults. If this - is enabled, only the event types listed in `additional_event_types` are shared. - Defaults to false. -* `additional_event_types`: Additional state event types to share with users when they are invited - to a room. By default, this list is empty (so only the default event types are shared). +* `disable_default_event_types`: boolean. Set to `true` to disable the above + defaults. If this is enabled, only the event types listed in + `additional_event_types` are shared. Defaults to `false`. +* `additional_event_types`: A list of additional state events to include in the + events to be shared. By default, this list is empty (so only the default event + types are shared). + + Each entry in this list should be either a single string or a list of two + strings. + * A standalone string `t` represents all events with type `t` (i.e. + with no restrictions on state keys). + * A pair of strings `[t, s]` represents a single event with type `t` and + state key `s`. The same type can appear in two entries with different state + keys: in this situation, both state keys are included in prejoin state. Example configuration: ```yaml room_prejoin_state: - disable_default_event_types: true + disable_default_event_types: false additional_event_types: - - org.example.custom.event.type - - m.room.join_rules + # Share all events of type `org.example.custom.event.typeA` + - org.example.custom.event.typeA + # Share only events of type `org.example.custom.event.typeB` whose + # state_key is "foo" + - ["org.example.custom.event.typeB", "foo"] + # Share only events of type `org.example.custom.event.typeC` whose + # state_key is "bar" or "baz" + - ["org.example.custom.event.typeC", "bar"] + - ["org.example.custom.event.typeC", "baz"] ``` + +*Changed in Synapse 1.74:* admins can filter the events in prejoin state based +on their state key. + --- ### `track_puppeted_user_ips` diff --git a/mypy.ini b/mypy.ini index 727536df50..37acf589c9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -89,6 +89,12 @@ disallow_untyped_defs = False [mypy-tests.*] disallow_untyped_defs = False +[mypy-tests.config.test_api] +disallow_untyped_defs = True + +[mypy-tests.federation.transport.test_client] +disallow_untyped_defs = True + [mypy-tests.handlers.test_sso] disallow_untyped_defs = True @@ -101,7 +107,7 @@ disallow_untyped_defs = True [mypy-tests.push.test_bulk_push_rule_evaluator] disallow_untyped_defs = True -[mypy-tests.test_server] +[mypy-tests.rest.*] disallow_untyped_defs = True [mypy-tests.state.test_profile] @@ -110,10 +116,10 @@ disallow_untyped_defs = True [mypy-tests.storage.*] disallow_untyped_defs = True -[mypy-tests.rest.*] +[mypy-tests.test_server] disallow_untyped_defs = True -[mypy-tests.federation.transport.test_client] +[mypy-tests.types.*] disallow_untyped_defs = True [mypy-tests.util.caches.*] diff --git a/synapse/config/_util.py b/synapse/config/_util.py index 3edb4b7106..d3a4b484ab 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -33,6 +33,9 @@ def validate_config( config: the configuration value to be validated config_path: the path within the config file. This will be used as a basis for the error message. + + Raises: + ConfigError, if validation fails. """ try: jsonschema.validate(config, json_schema) diff --git a/synapse/config/api.py b/synapse/config/api.py index e46728e73f..27d50d118f 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -13,12 +13,13 @@ # limitations under the License. import logging -from typing import Any, Iterable +from typing import Any, Iterable, Optional, Tuple from synapse.api.constants import EventTypes from synapse.config._base import Config, ConfigError from synapse.config._util import validate_config from synapse.types import JsonDict +from synapse.types.state import StateFilter logger = logging.getLogger(__name__) @@ -26,16 +27,20 @@ logger = logging.getLogger(__name__) class ApiConfig(Config): section = "api" + room_prejoin_state: StateFilter + track_puppetted_users_ips: bool + def read_config(self, config: JsonDict, **kwargs: Any) -> None: validate_config(_MAIN_SCHEMA, config, ()) - self.room_prejoin_state = list(self._get_prejoin_state_types(config)) + self.room_prejoin_state = StateFilter.from_types( + self._get_prejoin_state_entries(config) + ) self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) - def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]: - """Get the event types to include in the prejoin state - - Parses the config and returns an iterable of the event types to be included. - """ + def _get_prejoin_state_entries( + self, config: JsonDict + ) -> Iterable[Tuple[str, Optional[str]]]: + """Get the event types and state keys to include in the prejoin state.""" room_prejoin_state_config = config.get("room_prejoin_state") or {} # backwards-compatibility support for room_invite_state_types @@ -50,33 +55,39 @@ class ApiConfig(Config): logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING) - yield from config["room_invite_state_types"] + for event_type in config["room_invite_state_types"]: + yield event_type, None return if not room_prejoin_state_config.get("disable_default_event_types"): - yield from _DEFAULT_PREJOIN_STATE_TYPES + yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS - yield from room_prejoin_state_config.get("additional_event_types", []) + for entry in room_prejoin_state_config.get("additional_event_types", []): + if isinstance(entry, str): + yield entry, None + else: + yield entry _ROOM_INVITE_STATE_TYPES_WARNING = """\ WARNING: The 'room_invite_state_types' configuration setting is now deprecated, and replaced with 'room_prejoin_state'. New features may not work correctly -unless 'room_invite_state_types' is removed. See the sample configuration file for -details of 'room_prejoin_state'. +unless 'room_invite_state_types' is removed. See the config documentation at + https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state +for details of 'room_prejoin_state'. -------------------------------------------------------------------------------- """ -_DEFAULT_PREJOIN_STATE_TYPES = [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, +_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [ + (EventTypes.JoinRules, ""), + (EventTypes.CanonicalAlias, ""), + (EventTypes.RoomAvatar, ""), + (EventTypes.RoomEncryption, ""), + (EventTypes.Name, ""), # Per MSC1772. - EventTypes.Create, + (EventTypes.Create, ""), # Per MSC3173. - EventTypes.Topic, + (EventTypes.Topic, ""), ] @@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = { "disable_default_event_types": {"type": "boolean"}, "additional_event_types": { "type": "array", - "items": {"type": "string"}, + "items": { + "oneOf": [ + {"type": "string"}, + { + "type": "array", + "items": {"type": "string"}, + "minItems": 2, + "maxItems": 2, + }, + ], + }, }, }, }, diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 71853caad8..13fa93afb8 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -28,8 +28,14 @@ from typing import ( ) import attr +from canonicaljson import encode_canonical_json -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + MAX_PDU_SIZE, + EventContentFields, + EventTypes, + RelationTypes, +) from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion from synapse.types import JsonDict @@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None: elif not isinstance(value, (bool, str)) and value is not None: # Other potential JSON values (bool, None, str) are safe. raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON) + + +def maybe_upsert_event_field( + event: EventBase, container: JsonDict, key: str, value: object +) -> bool: + """Upsert an event field, but only if this doesn't make the event too large. + + Returns true iff the upsert took place. + """ + if key in container: + old_value: object = container[key] + container[key] = value + # NB: here and below, we assume that passing a non-None `time_now` argument to + # get_pdu_json doesn't increase the size of the encoded result. + upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + container[key] = old_value + else: + container[key] = value + upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE + if not upsert_okay: + del container[key] + + return upsert_okay diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d6e90ef259..845f683358 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext +from synapse.events.utils import maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler from synapse.logging import opentracing @@ -1739,12 +1740,15 @@ class EventCreationHandler: if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - event.unsigned[ - "invite_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, - membership_user_id=event.sender, + maybe_upsert_event_field( + event, + event.unsigned, + "invite_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + membership_user_id=event.sender, + ), ) invitee = UserID.from_string(event.state_key) @@ -1762,11 +1766,14 @@ class EventCreationHandler: event.signatures.update(returned_invite.signatures) if event.content["membership"] == Membership.KNOCK: - event.unsigned[ - "knock_room_state" - ] = await self.store.get_stripped_room_state_from_event_context( - context, - self.room_prejoin_state_types, + maybe_upsert_event_field( + event, + event.unsigned, + "knock_room_state", + await self.store.get_stripped_room_state_from_event_context( + context, + self.room_prejoin_state_types, + ), ) if event.type == EventTypes.Redaction: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 01e935edef..318fd7dc71 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -16,11 +16,11 @@ import logging import threading import weakref from enum import Enum, auto +from itertools import chain from typing import ( TYPE_CHECKING, Any, Collection, - Container, Dict, Iterable, List, @@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import ( ) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList @@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: Container[str], + state_keys_to_include: StateFilter, membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore): Args: context: The event context to retrieve state of the room from. - state_types_to_include: The type of state events to include. + state_keys_to_include: The state events to include, for each event type. membership_user_id: An optional user ID to include the stripped membership state events of. This is useful when generating the stripped state of a room for invites. We want to send membership events of the inviter, so that the @@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore): Returns: A list of dictionaries, each representing a stripped state event from the room. """ - current_state_ids = await context.get_current_state_ids() + if membership_user_id: + types = chain( + state_keys_to_include.to_types(), + [(EventTypes.Member, membership_user_id)], + ) + filter = StateFilter.from_types(types) + else: + filter = state_keys_to_include + selected_state_ids = await context.get_current_state_ids(filter) # We know this event is not an outlier, so this must be # non-None. - assert current_state_ids is not None - - # The state to include - state_to_include_ids = [ - e_id - for k, e_id in current_state_ids.items() - if k[0] in state_types_to_include - or (membership_user_id and k == (EventTypes.Member, membership_user_id)) - ] + assert selected_state_ids is not None + + # Confusingly, get_current_state_events may return events that are discarded by + # the filter, if they're in context._state_delta_due_to_event. Strip these away. + selected_state_ids = filter.filter_state(selected_state_ids) - state_to_include = await self.get_events(state_to_include_ids) + state_to_include = await self.get_events(selected_state_ids.values()) return [ { diff --git a/synapse/types/state.py b/synapse/types/state.py index 0004d955b4..743a4f9217 100644 --- a/synapse/types/state.py +++ b/synapse/types/state.py @@ -118,6 +118,15 @@ class StateFilter: ) ) + def to_types(self) -> Iterable[Tuple[str, Optional[str]]]: + """The inverse to `from_types`.""" + for (event_type, state_keys) in self.types.items(): + if state_keys is None: + yield event_type, None + else: + for state_key in state_keys: + yield event_type, state_key + @staticmethod def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member @@ -343,6 +352,15 @@ class StateFilter: for s in state_keys ] + def wildcard_types(self) -> List[str]: + """Returns a list of event types which require us to fetch all state keys. + This will be empty unless `has_wildcards` returns True. + + Returns: + A list of event types. + """ + return [t for t, state_keys in self.types.items() if state_keys is None] + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching diff --git a/tests/config/test_api.py b/tests/config/test_api.py new file mode 100644 index 0000000000..6773c9a277 --- /dev/null +++ b/tests/config/test_api.py @@ -0,0 +1,145 @@ +from unittest import TestCase as StdlibTestCase + +import yaml + +from synapse.config import ConfigError +from synapse.config.api import ApiConfig +from synapse.types.state import StateFilter + +DEFAULT_PREJOIN_STATE_PAIRS = { + ("m.room.join_rules", ""), + ("m.room.canonical_alias", ""), + ("m.room.avatar", ""), + ("m.room.encryption", ""), + ("m.room.name", ""), + ("m.room.create", ""), + ("m.room.topic", ""), +} + + +class TestRoomPrejoinState(StdlibTestCase): + def read_config(self, source: str) -> ApiConfig: + config = ApiConfig() + config.read_config(yaml.safe_load(source)) + return config + + def test_no_prejoin_state(self) -> None: + config = self.read_config("foo: bar") + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS + ) + + def test_disable_default_event_types(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + """ + ) + self.assertEqual(config.room_prejoin_state, StateFilter.none()) + + def test_event_without_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - foo + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + def test_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + """ + ) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar")}, + ) + + def test_repeated_event_with_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - [foo, baz] + """ + ) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar"), ("foo", "baz")}, + ) + + def test_no_specific_state_key_overrides_specific_state_key(self) -> None: + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - [foo, bar] + - foo + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + config = self.read_config( + """ +room_prejoin_state: + disable_default_event_types: true + additional_event_types: + - foo + - [foo, bar] + """ + ) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) + + def test_bad_event_type_entry_raises(self) -> None: + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [a, b, c] + """ + ) + + with self.assertRaises(ConfigError): + self.read_config( + """ +room_prejoin_state: + additional_event_types: + - [true, 1.23] + """ + ) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index b1c47efac7..a79256846f 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest as stdlib_unittest + from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.utils import ( SerializeEventConfig, copy_and_fixup_power_levels_contents, + maybe_upsert_event_field, prune_event, serialize_event, ) from synapse.util.frozenutils import freeze -from tests import unittest - def MockEvent(**kwargs): if "event_id" not in kwargs: @@ -34,7 +35,31 @@ def MockEvent(**kwargs): return make_event_from_dict(kwargs) -class PruneEventTestCase(unittest.TestCase): +class TestMaybeUpsertEventField(stdlib_unittest.TestCase): + def test_update_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + success = maybe_upsert_event_field(event, event.unsigned, "key", "value") + self.assertTrue(success) + self.assertEqual(event.unsigned["key"], "value") + + def test_update_not_okay(self) -> None: + event = make_event_from_dict({"event_id": "$1234"}) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertNotIn("key", event.unsigned) + + def test_update_not_okay_leaves_original_value(self) -> None: + event = make_event_from_dict( + {"event_id": "$1234", "unsigned": {"key": "value"}} + ) + LARGE_STRING = "a" * 100_000 + success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING) + self.assertFalse(success) + self.assertEqual(event.unsigned["key"], "value") + + +class PruneEventTestCase(stdlib_unittest.TestCase): def run_test(self, evdict, matchdict, **kwargs): """ Asserts that a new event constructed with `evdict` will look like @@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase): ) -class SerializeEventTestCase(unittest.TestCase): +class SerializeEventTestCase(stdlib_unittest.TestCase): def serialize(self, ev, fields): return serialize_event( ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields) @@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase): ) -class CopyPowerLevelsContentTestCase(unittest.TestCase): +class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def setUp(self) -> None: self.test_content = { "ban": 50, diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a433e70870..bad7f0bc60 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID from synapse.types.state import StateFilter from synapse.util import Clock -from tests.unittest import HomeserverTestCase, TestCase +from tests.unittest import HomeserverTestCase logger = logging.getLogger(__name__) @@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase): self.assertEqual(is_all, True) self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict) - - -class StateFilterDifferenceTestCase(TestCase): - def assert_difference( - self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter - ) -> None: - self.assertEqual( - minuend.approx_difference(subtrahend), - expected, - f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", - ) - - def test_state_filter_difference_no_include_other_minus_no_include_other( - self, - ) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b do not have the - include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Create: None}, include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=False, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.CanonicalAlias: {""}}, - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only a has the include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Create: None, - EventTypes.Member: set(), - EventTypes.CanonicalAlias: set(), - }, - include_others=True, - ), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - # This also shows that the resultant state filter is normalised. - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - StateFilter(types=frozendict(), include_others=True), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter( - types=frozendict(), - include_others=True, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.CanonicalAlias: {""}, - EventTypes.Member: set(), - }, - include_others=True, - ), - ) - - # (specific state keys) - (specific state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - ) - - def test_state_filter_difference_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), both a and b have the include_others - flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=True), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=False, - ), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - EventTypes.Create: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - EventTypes.Create: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - EventTypes.Create: {""}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: - """ - Tests the StateFilter.approx_difference method - where, in a.approx_difference(b), only b has the include_others flag set. - """ - # (wildcard on state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.Create: None}, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, - include_others=True, - ), - StateFilter(types=frozendict(), include_others=False), - ) - - # (wildcard on state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:spqr"}}, - include_others=True, - ), - StateFilter.freeze({EventTypes.Member: None}, include_others=False), - ) - - # (wildcard on state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # (specific state keys) - (wildcard on state keys): - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=True, - ), - StateFilter( - types=frozendict(), - include_others=False, - ), - ) - - # (specific state keys) - (specific state keys) - # This one is an over-approximation because we can't represent - # 'all state keys except a few named examples' - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr"}, - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@spqr:spqr"}, - }, - include_others=False, - ), - ) - - # (specific state keys) - (no state keys) - self.assert_difference( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - EventTypes.CanonicalAlias: {""}, - }, - include_others=False, - ), - StateFilter.freeze( - { - EventTypes.Member: set(), - }, - include_others=True, - ), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, - }, - include_others=False, - ), - ) - - def test_state_filter_difference_simple_cases(self) -> None: - """ - Tests some very simple cases of the StateFilter approx_difference, - that are not explicitly tested by the more in-depth tests. - """ - - self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) - - self.assert_difference( - StateFilter.all(), - StateFilter.none(), - StateFilter.all(), - ) - - -class StateFilterTestCase(TestCase): - def test_return_expanded(self) -> None: - """ - Tests the behaviour of the return_expanded() function that expands - StateFilters to include more state types (for the sake of cache hit rate). - """ - - self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) - - self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) - - # Concrete-only state filters stay the same - # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": {""}, - }, - include_others=False, - ), - ) - - # Concrete-only state filters stay the same - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - {"some.other.state.type": {""}}, include_others=False - ).return_expanded(), - StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), - ) - - # Concrete-only state filters stay the same - # (Case: member-only filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - }, - include_others=False, - ), - ) - - # Wildcard member-only state filters stay the same - self.assertEqual( - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: None}, - include_others=False, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. - # (Case: mixed filter) - self.assertEqual( - StateFilter.freeze( - { - EventTypes.Member: {"@wombat:test", "@alicia:test"}, - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze( - {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, - include_others=True, - ), - ) - - # If there is a wildcard in the non-member portion of the filter, - # it's expanded to include ALL non-member events. - # (Case: non-member-only filter) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) - self.assertEqual( - StateFilter.freeze( - { - "some.other.state.type": None, - "yet.another.state.type": {"wombat"}, - }, - include_others=False, - ).return_expanded(), - StateFilter.freeze({EventTypes.Member: set()}, include_others=True), - ) diff --git a/tests/types/__init__.py b/tests/types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/types/test_state.py b/tests/types/test_state.py new file mode 100644 index 0000000000..eb809f9fb7 --- /dev/null +++ b/tests/types/test_state.py @@ -0,0 +1,627 @@ +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types.state import StateFilter + +from tests.unittest import TestCase + + +class StateFilterDifferenceTestCase(TestCase): + def assert_difference( + self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter + ) -> None: + self.assertEqual( + minuend.approx_difference(subtrahend), + expected, + f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}", + ) + + def test_state_filter_difference_no_include_other_minus_no_include_other( + self, + ) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b do not have the + include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Create: None}, include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=False, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.CanonicalAlias: {""}}, + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_include_other_minus_no_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only a has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Create: None, + EventTypes.Member: set(), + EventTypes.CanonicalAlias: set(), + }, + include_others=True, + ), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + # This also shows that the resultant state filter is normalised. + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + StateFilter(types=frozendict(), include_others=True), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter( + types=frozendict(), + include_others=True, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.CanonicalAlias: {""}, + EventTypes.Member: set(), + }, + include_others=True, + ), + ) + + # (specific state keys) - (specific state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + ) + + def test_state_filter_difference_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), both a and b have the include_others + flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=True), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=False, + ), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + EventTypes.Create: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + EventTypes.Create: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + EventTypes.Create: {""}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_no_include_other_minus_include_other(self) -> None: + """ + Tests the StateFilter.approx_difference method + where, in a.approx_difference(b), only b has the include_others flag set. + """ + # (wildcard on state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.Create: None}, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None, EventTypes.CanonicalAlias: None}, + include_others=True, + ), + StateFilter(types=frozendict(), include_others=False), + ) + + # (wildcard on state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:spqr"}}, + include_others=True, + ), + StateFilter.freeze({EventTypes.Member: None}, include_others=False), + ) + + # (wildcard on state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # (specific state keys) - (wildcard on state keys): + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=True, + ), + StateFilter( + types=frozendict(), + include_others=False, + ), + ) + + # (specific state keys) - (specific state keys) + # This one is an over-approximation because we can't represent + # 'all state keys except a few named examples' + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr"}, + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@spqr:spqr"}, + }, + include_others=False, + ), + ) + + # (specific state keys) - (no state keys) + self.assert_difference( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + EventTypes.CanonicalAlias: {""}, + }, + include_others=False, + ), + StateFilter.freeze( + { + EventTypes.Member: set(), + }, + include_others=True, + ), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"}, + }, + include_others=False, + ), + ) + + def test_state_filter_difference_simple_cases(self) -> None: + """ + Tests some very simple cases of the StateFilter approx_difference, + that are not explicitly tested by the more in-depth tests. + """ + + self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none()) + + self.assert_difference( + StateFilter.all(), + StateFilter.none(), + StateFilter.all(), + ) + + +class StateFilterTestCase(TestCase): + def test_return_expanded(self) -> None: + """ + Tests the behaviour of the return_expanded() function that expands + StateFilters to include more state types (for the sake of cache hit rate). + """ + + self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all()) + + self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none()) + + # Concrete-only state filters stay the same + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": {""}, + }, + include_others=False, + ), + ) + + # Concrete-only state filters stay the same + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + {"some.other.state.type": {""}}, include_others=False + ).return_expanded(), + StateFilter.freeze({"some.other.state.type": {""}}, include_others=False), + ) + + # Concrete-only state filters stay the same + # (Case: member-only filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + }, + include_others=False, + ), + ) + + # Wildcard member-only state filters stay the same + self.assertEqual( + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: None}, + include_others=False, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: mixed filter) + self.assertEqual( + StateFilter.freeze( + { + EventTypes.Member: {"@wombat:test", "@alicia:test"}, + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze( + {EventTypes.Member: {"@wombat:test", "@alicia:test"}}, + include_others=True, + ), + ) + + # If there is a wildcard in the non-member portion of the filter, + # it's expanded to include ALL non-member events. + # (Case: non-member-only filter) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) + self.assertEqual( + StateFilter.freeze( + { + "some.other.state.type": None, + "yet.another.state.type": {"wombat"}, + }, + include_others=False, + ).return_expanded(), + StateFilter.freeze({EventTypes.Member: set()}, include_others=True), + ) -- cgit 1.5.1