From b5b5f6608462a988b05502a3b70b6a57ca3846d2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:19:30 +0000 Subject: Move `StateFilter` to `synapse.types` (#14668) * Move `StateFilter` to `synapse.types` * Changelog --- changelog.d/14668.misc | 1 + synapse/events/builder.py | 2 +- synapse/events/snapshot.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 2 +- synapse/handlers/message.py | 2 +- synapse/handlers/pagination.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/search.py | 2 +- synapse/handlers/sync.py | 2 +- synapse/module_api/__init__.py | 2 +- synapse/push/bulk_push_rule_evaluator.py | 2 +- synapse/push/mailer.py | 2 +- synapse/rest/admin/rooms.py | 2 +- synapse/rest/client/room.py | 2 +- synapse/state/__init__.py | 2 +- synapse/storage/controllers/persist_events.py | 2 +- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/state.py | 2 +- synapse/storage/databases/state/bg_updates.py | 2 +- synapse/storage/databases/state/store.py | 2 +- synapse/storage/state.py | 567 ---------------- synapse/types.py | 928 -------------------------- synapse/types/__init__.py | 928 ++++++++++++++++++++++++++ synapse/types/state.py | 567 ++++++++++++++++ synapse/visibility.py | 2 +- tests/storage/test_state.py | 2 +- 29 files changed, 1520 insertions(+), 1519 deletions(-) create mode 100644 changelog.d/14668.misc delete mode 100644 synapse/storage/state.py delete mode 100644 synapse/types.py create mode 100644 synapse/types/__init__.py create mode 100644 synapse/types/state.py diff --git a/changelog.d/14668.misc b/changelog.d/14668.misc new file mode 100644 index 0000000000..5269d8a97d --- /dev/null +++ b/changelog.d/14668.misc @@ -0,0 +1 @@ +Move `StateFilter` to `synapse.types`. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index d62906043f..94dd1298e1 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -28,8 +28,8 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.state import StateHandler from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import EventID, JsonDict +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.stringutils import random_string diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 1c0e96bec7..6eaef8b57a 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -23,7 +23,7 @@ from synapse.types import JsonDict, StateMap if TYPE_CHECKING: from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore - from synapse.storage.state import StateFilter + from synapse.types.state import StateFilter @attr.s(slots=True, auto_attribs=True) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3398fcaf7d..b2784d7333 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -70,8 +70,8 @@ from synapse.replication.http.federation import ( ) from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f7223b03c3..d2facdab60 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -75,7 +75,6 @@ from synapse.replication.http.federation import ( from synapse.state import StateResolutionStore from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, @@ -83,6 +82,7 @@ from synapse.types import ( UserID, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5cbe89f4fd..d6e90ef259 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -59,7 +59,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( MutableStateMap, PersistedEventPosition, @@ -70,6 +69,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index c572508a02..8c8ff18a1a 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -27,9 +27,9 @@ from synapse.handlers.room import ShutdownRoomResponse from synapse.logging.opentracing import trace from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamKeyType +from synapse.types.state import StateFilter from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6307fa9c5d..c611efb760 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -46,8 +46,8 @@ from synapse.replication.http.register import ( ReplicationRegisterServlet, ) from synapse.spam_checker_api import RegistrationBehaviour -from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester +from synapse.types.state import StateFilter if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6dcfd86fdf..f81241c2b3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -62,7 +62,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( JsonDict, @@ -77,6 +76,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import stringutils from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 6ad2b38b8f..0c39e852a1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -34,7 +34,6 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM -from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, Requester, @@ -45,6 +44,7 @@ from synapse.types import ( create_requester, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index bcab98c6d5..33115ce488 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,8 +23,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.events import EventBase -from synapse.storage.state import StateFilter from synapse.types import JsonDict, StreamKeyType, UserID +from synapse.types.state import StateFilter from synapse.visibility import filter_events_for_client if TYPE_CHECKING: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index dace9b606f..7d6a653747 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -49,7 +49,6 @@ from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary -from synapse.storage.state import StateFilter from synapse.types import ( DeviceListUpdates, JsonDict, @@ -61,6 +60,7 @@ from synapse.types import ( StreamToken, UserID, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 96a661177a..0092a03c59 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -111,7 +111,6 @@ from synapse.storage.background_updates import ( ) from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.types import ( DomainSpecificString, JsonDict, @@ -124,6 +123,7 @@ from synapse.types import ( UserProfile, create_requester, ) +from synapse.types.state import StateFilter from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches.descriptors import CachedFunction, cached diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 9ed35d8461..36e5b327ef 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -35,8 +35,8 @@ from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership -from synapse.storage.state import StateFilter from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index c2575ba3d9..93b255ced5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -37,8 +37,8 @@ from synapse.push.push_types import ( TemplateVars, ) from synapse.storage.databases.main.event_push_actions import EmailPushAction -from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID +from synapse.types.state import StateFilter from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 747e6fda83..e957aa28ca 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,9 +34,9 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, RoomID, UserID, create_requester +from synapse.types.state import StateFilter from synapse.util import json_decoder if TYPE_CHECKING: diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 514eb6afc8..790614d721 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -55,9 +55,9 @@ from synapse.logging.opentracing import set_tag from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types.state import StateFilter from synapse.util import json_decoder from synapse.util.cancellation import cancellable from synapse.util.stringutils import parse_and_validate_server_name, random_string diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 833ffec3de..ee5469d5a8 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -44,8 +44,8 @@ from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import StateMap +from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 33ffef521b..f1d2c71c91 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -58,13 +58,13 @@ from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.storage.state import StateFilter from synapse.types import ( PersistedEventPosition, RoomStreamToken, StateMap, get_domain_from_id, ) +from synapse.types.state import StateFilter from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 2b31ce54bb..26d79c6e62 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -31,12 +31,12 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging.opentracing import tag_args, trace from synapse.storage.roommember import ProfileInfo -from synapse.storage.state import StateFilter from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, PartialStateEventsTracker, ) from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.cancellation import cancellable if TYPE_CHECKING: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index af7bebee80..c801a93b5b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -33,8 +33,8 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore -from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index 4a4ad0f492..d743282f13 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -22,8 +22,8 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.storage.state import StateFilter from synapse.types import MutableStateMap, StateMap +from synapse.types.state import StateFilter from synapse.util.caches import intern_string if TYPE_CHECKING: diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index f8cfcaca83..1a7232b276 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -25,10 +25,10 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore -from synapse.storage.state import StateFilter from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator from synapse.types import MutableStateMap, StateKey, StateMap +from synapse.types.state import StateFilter from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.cancellation import cancellable diff --git a/synapse/storage/state.py b/synapse/storage/state.py deleted file mode 100644 index 0004d955b4..0000000000 --- a/synapse/storage/state.py +++ /dev/null @@ -1,567 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2022 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from typing import ( - TYPE_CHECKING, - Callable, - Collection, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - TypeVar, -) - -import attr -from frozendict import frozendict - -from synapse.api.constants import EventTypes -from synapse.types import MutableStateMap, StateKey, StateMap - -if TYPE_CHECKING: - from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - - -logger = logging.getLogger(__name__) - -# Used for generic functions below -T = TypeVar("T") - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StateFilter: - """A filter used when querying for state. - - Attributes: - types: Map from type to set of state keys (or None). This specifies - which state_keys for the given type to fetch from the DB. If None - then all events with that type are fetched. If the set is empty - then no events with that type are fetched. - include_others: Whether to fetch events with types that do not - appear in `types`. - """ - - types: "frozendict[str, Optional[FrozenSet[str]]]" - include_others: bool = False - - def __attrs_post_init__(self) -> None: - # If `include_others` is set we canonicalise the filter by removing - # wildcards from the types dictionary - if self.include_others: - # this is needed to work around the fact that StateFilter is frozen - object.__setattr__( - self, - "types", - frozendict({k: v for k, v in self.types.items() if v is not None}), - ) - - @staticmethod - def all() -> "StateFilter": - """Returns a filter that fetches everything. - - Returns: - The state filter. - """ - return _ALL_STATE_FILTER - - @staticmethod - def none() -> "StateFilter": - """Returns a filter that fetches nothing. - - Returns: - The new state filter. - """ - return _NONE_STATE_FILTER - - @staticmethod - def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": - """Creates a filter that only fetches the given types - - Args: - types: A list of type and state keys to fetch. A state_key of None - fetches everything for that type - - Returns: - The new state filter. - """ - type_dict: Dict[str, Optional[Set[str]]] = {} - for typ, s in types: - if typ in type_dict: - if type_dict[typ] is None: - continue - - if s is None: - type_dict[typ] = None - continue - - type_dict.setdefault(typ, set()).add(s) # type: ignore - - return StateFilter( - types=frozendict( - (k, frozenset(v) if v is not None else None) - for k, v in type_dict.items() - ) - ) - - @staticmethod - def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": - """Creates a filter that returns all non-member events, plus the member - events for the given users - - Args: - members: Set of user IDs - - Returns: - The new state filter - """ - return StateFilter( - types=frozendict({EventTypes.Member: frozenset(members)}), - include_others=True, - ) - - @staticmethod - def freeze( - types: Mapping[str, Optional[Collection[str]]], include_others: bool - ) -> "StateFilter": - """ - Returns a (frozen) StateFilter with the same contents as the parameters - specified here, which can be made of mutable types. - """ - types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} - for state_types, state_keys in types.items(): - if state_keys is not None: - types_with_frozen_values[state_types] = frozenset(state_keys) - else: - types_with_frozen_values[state_types] = None - - return StateFilter( - frozendict(types_with_frozen_values), include_others=include_others - ) - - def return_expanded(self) -> "StateFilter": - """Creates a new StateFilter where type wild cards have been removed - (except for memberships). The returned filter is a superset of the - current one, i.e. anything that passes the current filter will pass - the returned filter. - - This helps the caching as the DictionaryCache knows if it has *all* the - state, but does not know if it has all of the keys of a particular type, - which makes wildcard lookups expensive unless we have a complete cache. - Hence, if we are doing a wildcard lookup, populate the cache fully so - that we can do an efficient lookup next time. - - Note that since we have two caches, one for membership events and one for - other events, we can be a bit more clever than simply returning - `StateFilter.all()` if `has_wildcards()` is True. - - We return a StateFilter where: - 1. the list of membership events to return is the same - 2. if there is a wildcard that matches non-member events we - return all non-member events - - Returns: - The new state filter. - """ - - if self.is_full(): - # If we're going to return everything then there's nothing to do - return self - - if not self.has_wildcards(): - # If there are no wild cards, there's nothing to do - return self - - if EventTypes.Member in self.types: - get_all_members = self.types[EventTypes.Member] is None - else: - get_all_members = self.include_others - - has_non_member_wildcard = self.include_others or any( - state_keys is None - for t, state_keys in self.types.items() - if t != EventTypes.Member - ) - - if not has_non_member_wildcard: - # If there are no non-member wild cards we can just return ourselves - return self - - if get_all_members: - # We want to return everything. - return StateFilter.all() - elif EventTypes.Member in self.types: - # We want to return all non-members, but only particular - # memberships - return StateFilter( - types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), - include_others=True, - ) - else: - # We want to return all non-members - return _ALL_NON_MEMBER_STATE_FILTER - - def make_sql_filter_clause(self) -> Tuple[str, List[str]]: - """Converts the filter to an SQL clause. - - For example: - - f = StateFilter.from_types([("m.room.create", "")]) - clause, args = f.make_sql_filter_clause() - clause == "(type = ? AND state_key = ?)" - args == ['m.room.create', ''] - - - Returns: - The SQL string (may be empty) and arguments. An empty SQL string is - returned when the filter matches everything (i.e. is "full"). - """ - - where_clause = "" - where_args: List[str] = [] - - if self.is_full(): - return where_clause, where_args - - if not self.include_others and not self.types: - # i.e. this is an empty filter, so we need to return a clause that - # will match nothing - return "1 = 2", [] - - # First we build up a lost of clauses for each type/state_key combo - clauses = [] - for etype, state_keys in self.types.items(): - if state_keys is None: - clauses.append("(type = ?)") - where_args.append(etype) - continue - - for state_key in state_keys: - clauses.append("(type = ? AND state_key = ?)") - where_args.extend((etype, state_key)) - - # This will match anything that appears in `self.types` - where_clause = " OR ".join(clauses) - - # If we want to include stuff that's not in the types dict then we add - # a `OR type NOT IN (...)` clause to the end. - if self.include_others: - if where_clause: - where_clause += " OR " - - where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) - where_args.extend(self.types) - - return where_clause, where_args - - def max_entries_returned(self) -> Optional[int]: - """Returns the maximum number of entries this filter will return if - known, otherwise returns None. - - For example a simple state filter asking for `("m.room.create", "")` - will return 1, whereas the default state filter will return None. - - This is used to bail out early if the right number of entries have been - fetched. - """ - if self.has_wildcards(): - return None - - return len(self.concrete_types()) - - def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: - """Returns the state filtered with by this StateFilter. - - Args: - state: The state map to filter - - Returns: - The filtered state map. - This is a copy, so it's safe to mutate. - """ - if self.is_full(): - return dict(state_dict) - - filtered_state = {} - for k, v in state_dict.items(): - typ, state_key = k - if typ in self.types: - state_keys = self.types[typ] - if state_keys is None or state_key in state_keys: - filtered_state[k] = v - elif self.include_others: - filtered_state[k] = v - - return filtered_state - - def is_full(self) -> bool: - """Whether this filter fetches everything or not - - Returns: - True if the filter fetches everything. - """ - return self.include_others and not self.types - - def has_wildcards(self) -> bool: - """Whether the filter includes wildcards or is attempting to fetch - specific state. - - Returns: - True if the filter includes wildcards. - """ - - return self.include_others or any( - state_keys is None for state_keys in self.types.values() - ) - - def concrete_types(self) -> List[Tuple[str, str]]: - """Returns a list of concrete type/state_keys (i.e. not None) that - will be fetched. This will be a complete list if `has_wildcards` - returns False, but otherwise will be a subset (or even empty). - - Returns: - A list of type/state_keys tuples. - """ - return [ - (t, s) - for t, state_keys in self.types.items() - if state_keys is not None - for s in state_keys - ] - - def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: - """Return the filter split into two: one which assumes it's exclusively - matching against member state, and one which assumes it's matching - against non member state. - - This is useful due to the returned filters giving correct results for - `is_full()`, `has_wildcards()`, etc, when operating against maps that - either exclusively contain member events or only contain non-member - events. (Which is the case when dealing with the member vs non-member - state caches). - - Returns: - The member and non member filters - """ - - if EventTypes.Member in self.types: - state_keys = self.types[EventTypes.Member] - if state_keys is None: - member_filter = StateFilter.all() - else: - member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) - elif self.include_others: - member_filter = StateFilter.all() - else: - member_filter = StateFilter.none() - - non_member_filter = StateFilter( - types=frozendict( - {k: v for k, v in self.types.items() if k != EventTypes.Member} - ), - include_others=self.include_others, - ) - - return member_filter, non_member_filter - - def _decompose_into_four_parts( - self, - ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: - """ - Decomposes this state filter into 4 constituent parts, which can be - thought of as this: - all? - minus_wildcards + plus_wildcards + plus_state_keys - - where - * all represents ALL state - * minus_wildcards represents entire state types to remove - * plus_wildcards represents entire state types to add - * plus_state_keys represents individual state keys to add - - See `recompose_from_four_parts` for the other direction of this - correspondence. - """ - is_all = self.include_others - excluded_types: Set[str] = {t for t in self.types if is_all} - wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} - concrete_keys: Set[StateKey] = set(self.concrete_types()) - - return (is_all, excluded_types), (wildcard_types, concrete_keys) - - @staticmethod - def _recompose_from_four_parts( - all_part: bool, - minus_wildcards: Set[str], - plus_wildcards: Set[str], - plus_state_keys: Set[StateKey], - ) -> "StateFilter": - """ - Recomposes a state filter from 4 parts. - - See `decompose_into_four_parts` (the other direction of this - correspondence) for descriptions on each of the parts. - """ - - # {state type -> set of state keys OR None for wildcard} - # (The same structure as that of a StateFilter.) - new_types: Dict[str, Optional[Set[str]]] = {} - - # if we start with all, insert the excluded statetypes as empty sets - # to prevent them from being included - if all_part: - new_types.update({state_type: set() for state_type in minus_wildcards}) - - # insert the plus wildcards - new_types.update({state_type: None for state_type in plus_wildcards}) - - # insert the specific state keys - for state_type, state_key in plus_state_keys: - if state_type in new_types: - entry = new_types[state_type] - if entry is not None: - entry.add(state_key) - elif not all_part: - # don't insert if the entire type is already included by - # include_others as this would actually shrink the state allowed - # by this filter. - new_types[state_type] = {state_key} - - return StateFilter.freeze(new_types, include_others=all_part) - - def approx_difference(self, other: "StateFilter") -> "StateFilter": - """ - Returns a state filter which represents `self - other`. - - This is useful for determining what state remains to be pulled out of the - database if we want the state included by `self` but already have the state - included by `other`. - - The returned state filter - - MUST include all state events that are included by this filter (`self`) - unless they are included by `other`; - - MUST NOT include state events not included by this filter (`self`); and - - MAY be an over-approximation: the returned state filter - MAY additionally include some state events from `other`. - - This implementation attempts to return the narrowest such state filter. - In the case that `self` contains wildcards for state types where - `other` contains specific state keys, an approximation must be made: - the returned state filter keeps the wildcard, as state filters are not - able to express 'all state keys except some given examples'. - e.g. - StateFilter(m.room.member -> None (wildcard)) - minus - StateFilter(m.room.member -> {'@wombat:example.org'}) - is approximated as - StateFilter(m.room.member -> None (wildcard)) - """ - - # We first transform self and other into an alternative representation: - # - whether or not they include all events to begin with ('all') - # - if so, which event types are excluded? ('excludes') - # - which entire event types to include ('wildcards') - # - which concrete state keys to include ('concrete state keys') - (self_all, self_excludes), ( - self_wildcards, - self_concrete_keys, - ) = self._decompose_into_four_parts() - (other_all, other_excludes), ( - other_wildcards, - other_concrete_keys, - ) = other._decompose_into_four_parts() - - # Start with an estimate of the difference based on self - new_all = self_all - # Wildcards from the other can be added to the exclusion filter - new_excludes = self_excludes | other_wildcards - # We remove wildcards that appeared as wildcards in the other - new_wildcards = self_wildcards - other_wildcards - # We filter out the concrete state keys that appear in the other - # as wildcards or concrete state keys. - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in self_concrete_keys - if state_type not in other_wildcards - } - other_concrete_keys - - if other_all: - if self_all: - # If self starts with all, then we add as wildcards any - # types which appear in the other's exclusion filter (but - # aren't in the self exclusion filter). This is as the other - # filter will return everything BUT the types in its exclusion, so - # we need to add those excluded types that also match the self - # filter as wildcard types in the new filter. - new_wildcards |= other_excludes.difference(self_excludes) - - # If other is an `include_others` then the difference isn't. - new_all = False - # (We have no need for excludes when we don't start with all, as there - # is nothing to exclude.) - new_excludes = set() - - # We also filter out all state types that aren't in the exclusion - # list of the other. - new_wildcards &= other_excludes - new_concrete_keys = { - (state_type, state_key) - for (state_type, state_key) in new_concrete_keys - if state_type in other_excludes - } - - # Transform our newly-constructed state filter from the alternative - # representation back into the normal StateFilter representation. - return StateFilter._recompose_from_four_parts( - new_all, new_excludes, new_wildcards, new_concrete_keys - ) - - def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: - """Check if we need to wait for full state to complete to calculate this state - - If we have a state filter which is completely satisfied even with partial - state, then we don't need to await_full_state before we can return it. - - Args: - is_mine_id: a callable which confirms if a given state_key matches a mxid - of a local user - """ - # if we haven't requested membership events, then it depends on the value of - # 'include_others' - if EventTypes.Member not in self.types: - return self.include_others - - # if we're looking for *all* membership events, then we have to wait - member_state_keys = self.types[EventTypes.Member] - if member_state_keys is None: - return True - - # otherwise, consider whose membership we are looking for. If it's entirely - # local users, then we don't need to wait. - for state_key in member_state_keys: - if not is_mine_id(state_key): - # remote user - return True - - # local users only - return False - - -_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) -_ALL_NON_MEMBER_STATE_FILTER = StateFilter( - types=frozendict({EventTypes.Member: frozenset()}), include_others=True -) -_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/types.py b/synapse/types.py deleted file mode 100644 index f2d436ddc3..0000000000 --- a/synapse/types.py +++ /dev/null @@ -1,928 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import abc -import re -import string -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - List, - Mapping, - Match, - MutableMapping, - NoReturn, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) - -import attr -from frozendict import frozendict -from signedjson.key import decode_verify_key_bytes -from signedjson.types import VerifyKey -from typing_extensions import Final, TypedDict -from unpaddedbase64 import decode_base64 -from zope.interface import Interface - -from twisted.internet.defer import CancelledError -from twisted.internet.interfaces import ( - IReactorCore, - IReactorPluggableNameResolver, - IReactorSSL, - IReactorTCP, - IReactorThreads, - IReactorTime, -) - -from synapse.api.errors import Codes, SynapseError -from synapse.util.cancellation import cancellable -from synapse.util.stringutils import parse_and_validate_server_name - -if TYPE_CHECKING: - from synapse.appservice.api import ApplicationService - from synapse.storage.databases.main import DataStore, PurgeEventsStore - from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore - -# Define a state map type from type/state_key to T (usually an event ID or -# event) -T = TypeVar("T") -StateKey = Tuple[str, str] -StateMap = Mapping[StateKey, T] -MutableStateMap = MutableMapping[StateKey, T] - -# JSON types. These could be made stronger, but will do for now. -# A JSON-serialisable dict. -JsonDict = Dict[str, Any] -# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. -# Useful when you have a TypedDict which isn't going to be mutated and you don't want -# to cast to JsonDict everywhere. -JsonMapping = Mapping[str, Any] -# A JSON-serialisable object. -JsonSerializable = object - - -# Note that this seems to require inheriting *directly* from Interface in order -# for mypy-zope to realize it is an interface. -class ISynapseReactor( - IReactorTCP, - IReactorSSL, - IReactorPluggableNameResolver, - IReactorTime, - IReactorCore, - IReactorThreads, - Interface, -): - """The interfaces necessary for Synapse to function.""" - - -@attr.s(frozen=True, slots=True, auto_attribs=True) -class Requester: - """ - Represents the user making a request - - Attributes: - user: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request has been shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - """ - - user: "UserID" - access_token_id: Optional[int] - is_guest: bool - shadow_banned: bool - device_id: Optional[str] - app_service: Optional["ApplicationService"] - authenticated_entity: str - - def serialize(self) -> Dict[str, Any]: - """Converts self to a type that can be serialized as JSON, and then - deserialized by `deserialize` - - Returns: - dict - """ - return { - "user_id": self.user.to_string(), - "access_token_id": self.access_token_id, - "is_guest": self.is_guest, - "shadow_banned": self.shadow_banned, - "device_id": self.device_id, - "app_server_id": self.app_service.id if self.app_service else None, - "authenticated_entity": self.authenticated_entity, - } - - @staticmethod - def deserialize( - store: "ApplicationServiceWorkerStore", input: Dict[str, Any] - ) -> "Requester": - """Converts a dict that was produced by `serialize` back into a - Requester. - - Args: - store: Used to convert AS ID to AS object - input: A dict produced by `serialize` - - Returns: - Requester - """ - appservice = None - if input["app_server_id"]: - appservice = store.get_app_service_by_id(input["app_server_id"]) - - return Requester( - user=UserID.from_string(input["user_id"]), - access_token_id=input["access_token_id"], - is_guest=input["is_guest"], - shadow_banned=input["shadow_banned"], - device_id=input["device_id"], - app_service=appservice, - authenticated_entity=input["authenticated_entity"], - ) - - -def create_requester( - user_id: Union[str, "UserID"], - access_token_id: Optional[int] = None, - is_guest: bool = False, - shadow_banned: bool = False, - device_id: Optional[str] = None, - app_service: Optional["ApplicationService"] = None, - authenticated_entity: Optional[str] = None, -) -> Requester: - """ - Create a new ``Requester`` object - - Args: - user_id: id of the user making the request - access_token_id: *ID* of the access token used for this - request, or None if it came via the appservice API or similar - is_guest: True if the user making this request is a guest user - shadow_banned: True if the user making this request is shadow-banned. - device_id: device_id which was set at authentication time - app_service: the AS requesting on behalf of the user - authenticated_entity: The entity that authenticated when making the request. - This is different to the user_id when an admin user or the server is - "puppeting" the user. - - Returns: - Requester - """ - if not isinstance(user_id, UserID): - user_id = UserID.from_string(user_id) - - if authenticated_entity is None: - authenticated_entity = user_id.to_string() - - return Requester( - user_id, - access_token_id, - is_guest, - shadow_banned, - device_id, - app_service, - authenticated_entity, - ) - - -def get_domain_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[idx + 1 :] - - -def get_localpart_from_id(string: str) -> str: - idx = string.find(":") - if idx == -1: - raise SynapseError(400, "Invalid ID: %r" % (string,)) - return string[1:idx] - - -DS = TypeVar("DS", bound="DomainSpecificString") - - -@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) -class DomainSpecificString(metaclass=abc.ABCMeta): - """Common base class among ID/name strings that have a local part and a - domain name, prefixed with a sigil. - - Has the fields: - - 'localpart' : The local part of the name (without the leading sigil) - 'domain' : The domain part of the name - """ - - SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore - - localpart: str - domain: str - - # Because this is a frozen class, it is deeply immutable. - def __copy__(self: DS) -> DS: - return self - - def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: - return self - - @classmethod - def from_string(cls: Type[DS], s: str) -> DS: - """Parse the string given by 's' into a structure object.""" - if len(s) < 1 or s[0:1] != cls.SIGIL: - raise SynapseError( - 400, - "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - parts = s[1:].split(":", 1) - if len(parts) != 2: - raise SynapseError( - 400, - "Expected %s of the form '%slocalname:domain'" - % (cls.__name__, cls.SIGIL), - Codes.INVALID_PARAM, - ) - - domain = parts[1] - # This code will need changing if we want to support multiple domain - # names on one HS - return cls(localpart=parts[0], domain=domain) - - def to_string(self) -> str: - """Return a string encoding the fields of the structure object.""" - return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) - - @classmethod - def is_valid(cls: Type[DS], s: str) -> bool: - """Parses the input string and attempts to ensure it is valid.""" - # TODO: this does not reject an empty localpart or an overly-long string. - # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar - try: - obj = cls.from_string(s) - # Apply additional validation to the domain. This is only done - # during is_valid (and not part of from_string) since it is - # possible for invalid data to exist in room-state, etc. - parse_and_validate_server_name(obj.domain) - return True - except Exception: - return False - - __repr__ = to_string - - -@attr.s(slots=True, frozen=True, repr=False) -class UserID(DomainSpecificString): - """Structure representing a user ID.""" - - SIGIL = "@" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomAlias(DomainSpecificString): - """Structure representing a room name.""" - - SIGIL = "#" - - -@attr.s(slots=True, frozen=True, repr=False) -class RoomID(DomainSpecificString): - """Structure representing a room id.""" - - SIGIL = "!" - - -@attr.s(slots=True, frozen=True, repr=False) -class EventID(DomainSpecificString): - """Structure representing an event id.""" - - SIGIL = "$" - - -mxid_localpart_allowed_characters = set( - "_-./=" + string.ascii_lowercase + string.digits -) - - -def contains_invalid_mxid_characters(localpart: str) -> bool: - """Check for characters not allowed in an mxid or groupid localpart - - Args: - localpart: the localpart to be checked - - Returns: - True if there are any naughty characters - """ - return any(c not in mxid_localpart_allowed_characters for c in localpart) - - -UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") - -# the following is a pattern which matches '=', and bytes which are not allowed in a mxid -# localpart. -# -# It works by: -# * building a string containing the allowed characters (excluding '=') -# * escaping every special character with a backslash (to stop '-' being interpreted as a -# range operator) -# * wrapping it in a '[^...]' regex -# * converting the whole lot to a 'bytes' sequence, so that we can use it to match -# bytes rather than strings -# -NON_MXID_CHARACTER_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( - "ascii" - ) -) - - -def map_username_to_mxid_localpart( - username: Union[str, bytes], case_sensitive: bool = False -) -> str: - """Map a username onto a string suitable for a MXID - - This follows the algorithm laid out at - https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. - - Args: - username: username to be mapped - case_sensitive: true if TEST and test should be mapped - onto different mxids - - Returns: - string suitable for a mxid localpart - """ - if not isinstance(username, bytes): - username = username.encode("utf-8") - - # first we sort out upper-case characters - if case_sensitive: - - def f1(m: Match[bytes]) -> bytes: - return b"_" + m.group().lower() - - username = UPPER_CASE_PATTERN.sub(f1, username) - else: - username = username.lower() - - # then we sort out non-ascii characters by converting to the hex equivalent. - def f2(m: Match[bytes]) -> bytes: - return b"=%02x" % (m.group()[0],) - - username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) - - # we also do the =-escaping to mxids starting with an underscore. - username = re.sub(b"^_", b"=5f", username) - - # we should now only have ascii bytes left, so can decode back to a string. - return username.decode("ascii") - - -@attr.s(frozen=True, slots=True, order=False) -class RoomStreamToken: - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] ▼ [1] ▼ [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream, events are ordered by - `stream_ordering` (when they arrived at the homeserver). - - When traversing historic events, events are first ordered by their `depth` - (`topological_ordering` in the event graph) and tie-broken by - `stream_ordering` (when the event arrived at the homeserver). - - If you're looking for more info about what a token with all of the - underscores means, ex. - `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring - for `StreamToken` below. - - --- - - Live tokens start with an "s" followed by the `stream_ordering` of the event - that comes before the position of the token. Said another way: - `stream_ordering` uniquely identifies a persisted event. The live token - means "the position just after the event identified by `stream_ordering`". - An example token is: - - s2633508 - - --- - - Historic tokens start with a "t" followed by the `depth` - (`topological_ordering` in the event graph) of the event that comes before - the position of the token, followed by "-", followed by the - `stream_ordering` of the event that comes before the position of the token. - An example token is: - - t426-2633508 - - --- - - There is also a third mode for live tokens where the token starts with "m", - which is sometimes used when using sharded event persisters. In this case - the events stream is considered to be a set of streams (one for each writer) - and the token encodes the vector clock of positions of each writer in their - respective streams. - - The format of the token in such case is an initial integer min position, - followed by the mapping of instance ID to position separated by '.' and '~': - - m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... - - The `min_pos` corresponds to the minimum position all writers have persisted - up to, and then only writers that are ahead of that position need to be - encoded. An example token is: - - m56~2.58~3.59 - - Which corresponds to a set of three (or more writers) where instances 2 and - 3 (these are instance IDs that can be looked up in the DB to fetch the more - commonly used instance names) are at positions 58 and 59 respectively, and - all other instances are at position 56. - - Note: The `RoomStreamToken` cannot have both a topological part and an - instance map. - - --- - - For caching purposes, `RoomStreamToken`s and by extension, all their - attributes, must be hashable. - """ - - topological: Optional[int] = attr.ib( - validator=attr.validators.optional(attr.validators.instance_of(int)), - ) - stream: int = attr.ib(validator=attr.validators.instance_of(int)) - - instance_map: "frozendict[str, int]" = attr.ib( - factory=frozendict, - validator=attr.validators.deep_mapping( - key_validator=attr.validators.instance_of(str), - value_validator=attr.validators.instance_of(int), - mapping_validator=attr.validators.instance_of(frozendict), - ), - ) - - def __attrs_post_init__(self) -> None: - """Validates that both `topological` and `instance_map` aren't set.""" - - if self.instance_map and self.topological: - raise ValueError( - "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." - ) - - @classmethod - async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - if string[0] == "t": - parts = string[1:].split("-", 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - if string[0] == "m": - parts = string[1:].split("~") - stream = int(parts[0]) - - instance_map = {} - for part in parts[1:]: - key, value = part.split(".") - instance_id = int(key) - pos = int(value) - - instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] - instance_map[instance_name] = pos - - return cls( - topological=None, - stream=stream, - instance_map=frozendict(instance_map), - ) - except CancelledError: - raise - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - @classmethod - def parse_stream_token(cls, string: str) -> "RoomStreamToken": - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - except Exception: - pass - raise SynapseError(400, "Invalid room stream token %r" % (string,)) - - def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": - """Return a new token such that if an event is after both this token and - the other token, then its after the returned token too. - """ - - if self.topological or other.topological: - raise Exception("Can't advance topological tokens") - - max_stream = max(self.stream, other.stream) - - instance_map = { - instance: max( - self.instance_map.get(instance, self.stream), - other.instance_map.get(instance, other.stream), - ) - for instance in set(self.instance_map).union(other.instance_map) - } - - return RoomStreamToken(None, max_stream, frozendict(instance_map)) - - def as_historical_tuple(self) -> Tuple[int, int]: - """Returns a tuple of `(topological, stream)` for historical tokens. - - Raises if not an historical token (i.e. doesn't have a topological part). - """ - if self.topological is None: - raise Exception( - "Cannot call `RoomStreamToken.as_historical_tuple` on live token" - ) - - return self.topological, self.stream - - def get_stream_pos_for_instance(self, instance_name: str) -> int: - """Get the stream position that the given writer was at at this token. - - This only makes sense for "live" tokens that may have a vector clock - component, and so asserts that this is a "live" token. - """ - assert self.topological is None - - # If we don't have an entry for the instance we can assume that it was - # at `self.stream`. - return self.instance_map.get(instance_name, self.stream) - - def get_max_stream_pos(self) -> int: - """Get the maximum stream position referenced in this token. - - The corresponding "min" position is, by definition just `self.stream`. - - This is used to handle tokens that have non-empty `instance_map`, and so - reference stream positions after the `self.stream` position. - """ - return max(self.instance_map.values(), default=self.stream) - - async def to_string(self, store: "DataStore") -> str: - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - elif self.instance_map: - entries = [] - for name, pos in self.instance_map.items(): - instance_id = await store.get_id_for_instance(name) - entries.append(f"{instance_id}.{pos}") - - encoded_map = "~".join(entries) - return f"m{self.stream}~{encoded_map}" - else: - return "s%d" % (self.stream,) - - -class StreamKeyType: - """Known stream types. - - A stream is a list of entities ordered by an incrementing "stream token". - """ - - ROOM: Final = "room_key" - PRESENCE: Final = "presence_key" - TYPING: Final = "typing_key" - RECEIPT: Final = "receipt_key" - ACCOUNT_DATA: Final = "account_data_key" - PUSH_RULES: Final = "push_rules_key" - TO_DEVICE: Final = "to_device_key" - DEVICE_LIST: Final = "device_list_key" - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class StreamToken: - """A collection of keys joined together by underscores in the following - order and which represent the position in their respective streams. - - ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` - 1. `room_key`: `s2633508` which is a `RoomStreamToken` - - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` - - See the docstring for `RoomStreamToken` for more details. - 2. `presence_key`: `17` - 3. `typing_key`: `338` - 4. `receipt_key`: `6732159` - 5. `account_data_key`: `1082514` - 6. `push_rules_key`: `541479` - 7. `to_device_key`: `274711` - 8. `device_list_key`: `265584` - 9. `groups_key`: `1` (note that this key is now unused) - - You can see how many of these keys correspond to the various - fields in a "/sync" response: - ```json - { - "next_batch": "s12_4_0_1_1_1_1_4_1", - "presence": { - "events": [] - }, - "device_lists": { - "changed": [] - }, - "rooms": { - "join": { - "!QrZlfIDQLNLdZHqTnt:hs1": { - "timeline": { - "events": [], - "prev_batch": "s10_4_0_1_1_1_1_4_1", - "limited": false - }, - "state": { - "events": [] - }, - "account_data": { - "events": [] - }, - "ephemeral": { - "events": [] - } - } - } - } - } - ``` - - --- - - For caching purposes, `StreamToken`s and by extension, all their attributes, - must be hashable. - """ - - room_key: RoomStreamToken = attr.ib( - validator=attr.validators.instance_of(RoomStreamToken) - ) - presence_key: int - typing_key: int - receipt_key: int - account_data_key: int - push_rules_key: int - to_device_key: int - device_list_key: int - # Note that the groups key is no longer used and may have bogus values. - groups_key: int - - _SEPARATOR = "_" - START: ClassVar["StreamToken"] - - @classmethod - @cancellable - async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": - """ - Creates a RoomStreamToken from its textual representation. - """ - try: - keys = string.split(cls._SEPARATOR) - while len(keys) < len(attr.fields(cls)): - # i.e. old token from before receipt_key - keys.append("0") - return cls( - await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) - ) - except CancelledError: - raise - except Exception: - raise SynapseError(400, "Invalid stream token") - - async def to_string(self, store: "DataStore") -> str: - return self._SEPARATOR.join( - [ - await self.room_key.to_string(store), - str(self.presence_key), - str(self.typing_key), - str(self.receipt_key), - str(self.account_data_key), - str(self.push_rules_key), - str(self.to_device_key), - str(self.device_list_key), - # Note that the groups key is no longer used, but it is still - # serialized so that there will not be confusion in the future - # if additional tokens are added. - str(self.groups_key), - ] - ) - - @property - def room_stream_id(self) -> int: - return self.room_key.stream - - def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": - """Advance the given key in the token to a new value if and only if the - new value is after the old value. - - :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. - """ - if key == StreamKeyType.ROOM: - new_token = self.copy_and_replace( - StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) - ) - return new_token - - new_token = self.copy_and_replace(key, new_value) - new_id = int(getattr(new_token, key)) - old_id = int(getattr(self, key)) - - if old_id < new_id: - return new_token - else: - return self - - def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": - return attr.evolve(self, **{key: new_value}) - - -StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class PersistedEventPosition: - """Position of a newly persisted event with instance that persisted it. - - This can be used to test whether the event is persisted before or after a - RoomStreamToken. - """ - - instance_name: str - stream: int - - def persisted_after(self, token: RoomStreamToken) -> bool: - return token.get_stream_pos_for_instance(self.instance_name) < self.stream - - def to_room_stream_token(self) -> RoomStreamToken: - """Converts the position to a room stream token such that events - persisted in the same room after this position will be after the - returned `RoomStreamToken`. - - Note: no guarantees are made about ordering w.r.t. events in other - rooms. - """ - # Doing the naive thing satisfies the desired properties described in - # the docstring. - return RoomStreamToken(None, self.stream) - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ThirdPartyInstanceID: - appservice_id: Optional[str] - network_id: Optional[str] - - # Deny iteration because it will bite you if you try to create a singleton - # set by: - # users = set(user) - def __iter__(self) -> NoReturn: - raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) - - # Because this class is a frozen class, it is deeply immutable. - def __copy__(self) -> "ThirdPartyInstanceID": - return self - - def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": - return self - - @classmethod - def from_string(cls, s: str) -> "ThirdPartyInstanceID": - bits = s.split("|", 2) - if len(bits) != 2: - raise SynapseError(400, "Invalid ID %r" % (s,)) - - return cls(appservice_id=bits[0], network_id=bits[1]) - - def to_string(self) -> str: - return "%s|%s" % (self.appservice_id, self.network_id) - - __str__ = to_string - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class ReadReceipt: - """Information about a read-receipt""" - - room_id: str - receipt_type: str - user_id: str - event_ids: List[str] - thread_id: Optional[str] - data: JsonDict - - -@attr.s(slots=True, frozen=True, auto_attribs=True) -class DeviceListUpdates: - """ - An object containing a diff of information regarding other users' device lists, intended for - a recipient to carry out device list tracking. - - Attributes: - changed: A set of users whose device lists have changed recently. - left: A set of users who the recipient no longer needs to track the device lists of. - Typically when those users no longer share any end-to-end encryption enabled rooms. - """ - - # We need to use a factory here, otherwise `set` is not evaluated at - # object instantiation, but instead at class definition instantiation. - # The latter happening only once, thus always giving you the same sets - # across multiple DeviceListUpdates instances. - # Also see: don't define mutable default arguments. - changed: Set[str] = attr.ib(factory=set) - left: Set[str] = attr.ib(factory=set) - - def __bool__(self) -> bool: - return bool(self.changed or self.left) - - -def get_verify_key_from_cross_signing_key( - key_info: Mapping[str, Any] -) -> Tuple[str, VerifyKey]: - """Get the key ID and signedjson verify key from a cross-signing key dict - - Args: - key_info: a cross-signing key dict, which must have a "keys" - property that has exactly one item in it - - Returns: - the key ID and verify key for the cross-signing key - """ - # make sure that a `keys` field is provided - if "keys" not in key_info: - raise ValueError("Invalid key") - keys = key_info["keys"] - # and that it contains exactly one key - if len(keys) == 1: - key_id, key_data = next(iter(keys.items())) - return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) - else: - raise ValueError("Invalid key") - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class UserInfo: - """Holds information about a user. Result of get_userinfo_by_id. - - Attributes: - user_id: ID of the user. - appservice_id: Application service ID that created this user. - consent_server_notice_sent: Version of policy documents the user has been sent. - consent_version: Version of policy documents the user has consented to. - creation_ts: Creation timestamp of the user. - is_admin: True if the user is an admin. - is_deactivated: True if the user has been deactivated. - is_guest: True if the user is a guest user. - is_shadow_banned: True if the user has been shadow-banned. - user_type: User type (None for normal user, 'support' and 'bot' other options). - """ - - user_id: UserID - appservice_id: Optional[int] - consent_server_notice_sent: Optional[str] - consent_version: Optional[str] - user_type: Optional[str] - creation_ts: int - is_admin: bool - is_deactivated: bool - is_guest: bool - is_shadow_banned: bool - - -class UserProfile(TypedDict): - user_id: str - display_name: Optional[str] - avatar_url: Optional[str] - - -@attr.s(auto_attribs=True, frozen=True, slots=True) -class RetentionPolicy: - min_lifetime: Optional[int] = None - max_lifetime: Optional[int] = None diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py new file mode 100644 index 0000000000..f2d436ddc3 --- /dev/null +++ b/synapse/types/__init__.py @@ -0,0 +1,928 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import re +import string +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Mapping, + Match, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +import attr +from frozendict import frozendict +from signedjson.key import decode_verify_key_bytes +from signedjson.types import VerifyKey +from typing_extensions import Final, TypedDict +from unpaddedbase64 import decode_base64 +from zope.interface import Interface + +from twisted.internet.defer import CancelledError +from twisted.internet.interfaces import ( + IReactorCore, + IReactorPluggableNameResolver, + IReactorSSL, + IReactorTCP, + IReactorThreads, + IReactorTime, +) + +from synapse.api.errors import Codes, SynapseError +from synapse.util.cancellation import cancellable +from synapse.util.stringutils import parse_and_validate_server_name + +if TYPE_CHECKING: + from synapse.appservice.api import ApplicationService + from synapse.storage.databases.main import DataStore, PurgeEventsStore + from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore + +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateKey = Tuple[str, str] +StateMap = Mapping[StateKey, T] +MutableStateMap = MutableMapping[StateKey, T] + +# JSON types. These could be made stronger, but will do for now. +# A JSON-serialisable dict. +JsonDict = Dict[str, Any] +# A JSON-serialisable mapping; roughly speaking an immutable JSONDict. +# Useful when you have a TypedDict which isn't going to be mutated and you don't want +# to cast to JsonDict everywhere. +JsonMapping = Mapping[str, Any] +# A JSON-serialisable object. +JsonSerializable = object + + +# Note that this seems to require inheriting *directly* from Interface in order +# for mypy-zope to realize it is an interface. +class ISynapseReactor( + IReactorTCP, + IReactorSSL, + IReactorPluggableNameResolver, + IReactorTime, + IReactorCore, + IReactorThreads, + Interface, +): + """The interfaces necessary for Synapse to function.""" + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class Requester: + """ + Represents the user making a request + + Attributes: + user: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request has been shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + """ + + user: "UserID" + access_token_id: Optional[int] + is_guest: bool + shadow_banned: bool + device_id: Optional[str] + app_service: Optional["ApplicationService"] + authenticated_entity: str + + def serialize(self) -> Dict[str, Any]: + """Converts self to a type that can be serialized as JSON, and then + deserialized by `deserialize` + + Returns: + dict + """ + return { + "user_id": self.user.to_string(), + "access_token_id": self.access_token_id, + "is_guest": self.is_guest, + "shadow_banned": self.shadow_banned, + "device_id": self.device_id, + "app_server_id": self.app_service.id if self.app_service else None, + "authenticated_entity": self.authenticated_entity, + } + + @staticmethod + def deserialize( + store: "ApplicationServiceWorkerStore", input: Dict[str, Any] + ) -> "Requester": + """Converts a dict that was produced by `serialize` back into a + Requester. + + Args: + store: Used to convert AS ID to AS object + input: A dict produced by `serialize` + + Returns: + Requester + """ + appservice = None + if input["app_server_id"]: + appservice = store.get_app_service_by_id(input["app_server_id"]) + + return Requester( + user=UserID.from_string(input["user_id"]), + access_token_id=input["access_token_id"], + is_guest=input["is_guest"], + shadow_banned=input["shadow_banned"], + device_id=input["device_id"], + app_service=appservice, + authenticated_entity=input["authenticated_entity"], + ) + + +def create_requester( + user_id: Union[str, "UserID"], + access_token_id: Optional[int] = None, + is_guest: bool = False, + shadow_banned: bool = False, + device_id: Optional[str] = None, + app_service: Optional["ApplicationService"] = None, + authenticated_entity: Optional[str] = None, +) -> Requester: + """ + Create a new ``Requester`` object + + Args: + user_id: id of the user making the request + access_token_id: *ID* of the access token used for this + request, or None if it came via the appservice API or similar + is_guest: True if the user making this request is a guest user + shadow_banned: True if the user making this request is shadow-banned. + device_id: device_id which was set at authentication time + app_service: the AS requesting on behalf of the user + authenticated_entity: The entity that authenticated when making the request. + This is different to the user_id when an admin user or the server is + "puppeting" the user. + + Returns: + Requester + """ + if not isinstance(user_id, UserID): + user_id = UserID.from_string(user_id) + + if authenticated_entity is None: + authenticated_entity = user_id.to_string() + + return Requester( + user_id, + access_token_id, + is_guest, + shadow_banned, + device_id, + app_service, + authenticated_entity, + ) + + +def get_domain_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[idx + 1 :] + + +def get_localpart_from_id(string: str) -> str: + idx = string.find(":") + if idx == -1: + raise SynapseError(400, "Invalid ID: %r" % (string,)) + return string[1:idx] + + +DS = TypeVar("DS", bound="DomainSpecificString") + + +@attr.s(slots=True, frozen=True, repr=False, auto_attribs=True) +class DomainSpecificString(metaclass=abc.ABCMeta): + """Common base class among ID/name strings that have a local part and a + domain name, prefixed with a sigil. + + Has the fields: + + 'localpart' : The local part of the name (without the leading sigil) + 'domain' : The domain part of the name + """ + + SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore + + localpart: str + domain: str + + # Because this is a frozen class, it is deeply immutable. + def __copy__(self: DS) -> DS: + return self + + def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: + return self + + @classmethod + def from_string(cls: Type[DS], s: str) -> DS: + """Parse the string given by 's' into a structure object.""" + if len(s) < 1 or s[0:1] != cls.SIGIL: + raise SynapseError( + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + parts = s[1:].split(":", 1) + if len(parts) != 2: + raise SynapseError( + 400, + "Expected %s of the form '%slocalname:domain'" + % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + domain = parts[1] + # This code will need changing if we want to support multiple domain + # names on one HS + return cls(localpart=parts[0], domain=domain) + + def to_string(self) -> str: + """Return a string encoding the fields of the structure object.""" + return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) + + @classmethod + def is_valid(cls: Type[DS], s: str) -> bool: + """Parses the input string and attempts to ensure it is valid.""" + # TODO: this does not reject an empty localpart or an overly-long string. + # See https://spec.matrix.org/v1.2/appendices/#identifier-grammar + try: + obj = cls.from_string(s) + # Apply additional validation to the domain. This is only done + # during is_valid (and not part of from_string) since it is + # possible for invalid data to exist in room-state, etc. + parse_and_validate_server_name(obj.domain) + return True + except Exception: + return False + + __repr__ = to_string + + +@attr.s(slots=True, frozen=True, repr=False) +class UserID(DomainSpecificString): + """Structure representing a user ID.""" + + SIGIL = "@" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomAlias(DomainSpecificString): + """Structure representing a room name.""" + + SIGIL = "#" + + +@attr.s(slots=True, frozen=True, repr=False) +class RoomID(DomainSpecificString): + """Structure representing a room id.""" + + SIGIL = "!" + + +@attr.s(slots=True, frozen=True, repr=False) +class EventID(DomainSpecificString): + """Structure representing an event id.""" + + SIGIL = "$" + + +mxid_localpart_allowed_characters = set( + "_-./=" + string.ascii_lowercase + string.digits +) + + +def contains_invalid_mxid_characters(localpart: str) -> bool: + """Check for characters not allowed in an mxid or groupid localpart + + Args: + localpart: the localpart to be checked + + Returns: + True if there are any naughty characters + """ + return any(c not in mxid_localpart_allowed_characters for c in localpart) + + +UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") + +# the following is a pattern which matches '=', and bytes which are not allowed in a mxid +# localpart. +# +# It works by: +# * building a string containing the allowed characters (excluding '=') +# * escaping every special character with a backslash (to stop '-' being interpreted as a +# range operator) +# * wrapping it in a '[^...]' regex +# * converting the whole lot to a 'bytes' sequence, so that we can use it to match +# bytes rather than strings +# +NON_MXID_CHARACTER_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode( + "ascii" + ) +) + + +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: + """Map a username onto a string suitable for a MXID + + This follows the algorithm laid out at + https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. + + Args: + username: username to be mapped + case_sensitive: true if TEST and test should be mapped + onto different mxids + + Returns: + string suitable for a mxid localpart + """ + if not isinstance(username, bytes): + username = username.encode("utf-8") + + # first we sort out upper-case characters + if case_sensitive: + + def f1(m: Match[bytes]) -> bytes: + return b"_" + m.group().lower() + + username = UPPER_CASE_PATTERN.sub(f1, username) + else: + username = username.lower() + + # then we sort out non-ascii characters by converting to the hex equivalent. + def f2(m: Match[bytes]) -> bytes: + return b"=%02x" % (m.group()[0],) + + username = NON_MXID_CHARACTER_PATTERN.sub(f2, username) + + # we also do the =-escaping to mxids starting with an underscore. + username = re.sub(b"^_", b"=5f", username) + + # we should now only have ascii bytes left, so can decode back to a string. + return username.decode("ascii") + + +@attr.s(frozen=True, slots=True, order=False) +class RoomStreamToken: + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] ▼ [1] ▼ [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream, events are ordered by + `stream_ordering` (when they arrived at the homeserver). + + When traversing historic events, events are first ordered by their `depth` + (`topological_ordering` in the event graph) and tie-broken by + `stream_ordering` (when the event arrived at the homeserver). + + If you're looking for more info about what a token with all of the + underscores means, ex. + `s2633508_17_338_6732159_1082514_541479_274711_265584_1`, see the docstring + for `StreamToken` below. + + --- + + Live tokens start with an "s" followed by the `stream_ordering` of the event + that comes before the position of the token. Said another way: + `stream_ordering` uniquely identifies a persisted event. The live token + means "the position just after the event identified by `stream_ordering`". + An example token is: + + s2633508 + + --- + + Historic tokens start with a "t" followed by the `depth` + (`topological_ordering` in the event graph) of the event that comes before + the position of the token, followed by "-", followed by the + `stream_ordering` of the event that comes before the position of the token. + An example token is: + + t426-2633508 + + --- + + There is also a third mode for live tokens where the token starts with "m", + which is sometimes used when using sharded event persisters. In this case + the events stream is considered to be a set of streams (one for each writer) + and the token encodes the vector clock of positions of each writer in their + respective streams. + + The format of the token in such case is an initial integer min position, + followed by the mapping of instance ID to position separated by '.' and '~': + + m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ... + + The `min_pos` corresponds to the minimum position all writers have persisted + up to, and then only writers that are ahead of that position need to be + encoded. An example token is: + + m56~2.58~3.59 + + Which corresponds to a set of three (or more writers) where instances 2 and + 3 (these are instance IDs that can be looked up in the DB to fetch the more + commonly used instance names) are at positions 58 and 59 respectively, and + all other instances are at position 56. + + Note: The `RoomStreamToken` cannot have both a topological part and an + instance map. + + --- + + For caching purposes, `RoomStreamToken`s and by extension, all their + attributes, must be hashable. + """ + + topological: Optional[int] = attr.ib( + validator=attr.validators.optional(attr.validators.instance_of(int)), + ) + stream: int = attr.ib(validator=attr.validators.instance_of(int)) + + instance_map: "frozendict[str, int]" = attr.ib( + factory=frozendict, + validator=attr.validators.deep_mapping( + key_validator=attr.validators.instance_of(str), + value_validator=attr.validators.instance_of(int), + mapping_validator=attr.validators.instance_of(frozendict), + ), + ) + + def __attrs_post_init__(self) -> None: + """Validates that both `topological` and `instance_map` aren't set.""" + + if self.instance_map and self.topological: + raise ValueError( + "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." + ) + + @classmethod + async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + if string[0] == "t": + parts = string[1:].split("-", 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + if string[0] == "m": + parts = string[1:].split("~") + stream = int(parts[0]) + + instance_map = {} + for part in parts[1:]: + key, value = part.split(".") + instance_id = int(key) + pos = int(value) + + instance_name = await store.get_name_from_instance_id(instance_id) # type: ignore[attr-defined] + instance_map[instance_name] = pos + + return cls( + topological=None, + stream=stream, + instance_map=frozendict(instance_map), + ) + except CancelledError: + raise + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + except Exception: + pass + raise SynapseError(400, "Invalid room stream token %r" % (string,)) + + def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken": + """Return a new token such that if an event is after both this token and + the other token, then its after the returned token too. + """ + + if self.topological or other.topological: + raise Exception("Can't advance topological tokens") + + max_stream = max(self.stream, other.stream) + + instance_map = { + instance: max( + self.instance_map.get(instance, self.stream), + other.instance_map.get(instance, other.stream), + ) + for instance in set(self.instance_map).union(other.instance_map) + } + + return RoomStreamToken(None, max_stream, frozendict(instance_map)) + + def as_historical_tuple(self) -> Tuple[int, int]: + """Returns a tuple of `(topological, stream)` for historical tokens. + + Raises if not an historical token (i.e. doesn't have a topological part). + """ + if self.topological is None: + raise Exception( + "Cannot call `RoomStreamToken.as_historical_tuple` on live token" + ) + + return self.topological, self.stream + + def get_stream_pos_for_instance(self, instance_name: str) -> int: + """Get the stream position that the given writer was at at this token. + + This only makes sense for "live" tokens that may have a vector clock + component, and so asserts that this is a "live" token. + """ + assert self.topological is None + + # If we don't have an entry for the instance we can assume that it was + # at `self.stream`. + return self.instance_map.get(instance_name, self.stream) + + def get_max_stream_pos(self) -> int: + """Get the maximum stream position referenced in this token. + + The corresponding "min" position is, by definition just `self.stream`. + + This is used to handle tokens that have non-empty `instance_map`, and so + reference stream positions after the `self.stream` position. + """ + return max(self.instance_map.values(), default=self.stream) + + async def to_string(self, store: "DataStore") -> str: + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + elif self.instance_map: + entries = [] + for name, pos in self.instance_map.items(): + instance_id = await store.get_id_for_instance(name) + entries.append(f"{instance_id}.{pos}") + + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + else: + return "s%d" % (self.stream,) + + +class StreamKeyType: + """Known stream types. + + A stream is a list of entities ordered by an incrementing "stream token". + """ + + ROOM: Final = "room_key" + PRESENCE: Final = "presence_key" + TYPING: Final = "typing_key" + RECEIPT: Final = "receipt_key" + ACCOUNT_DATA: Final = "account_data_key" + PUSH_RULES: Final = "push_rules_key" + TO_DEVICE: Final = "to_device_key" + DEVICE_LIST: Final = "device_list_key" + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StreamToken: + """A collection of keys joined together by underscores in the following + order and which represent the position in their respective streams. + + ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1` + 1. `room_key`: `s2633508` which is a `RoomStreamToken` + - `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59` + - See the docstring for `RoomStreamToken` for more details. + 2. `presence_key`: `17` + 3. `typing_key`: `338` + 4. `receipt_key`: `6732159` + 5. `account_data_key`: `1082514` + 6. `push_rules_key`: `541479` + 7. `to_device_key`: `274711` + 8. `device_list_key`: `265584` + 9. `groups_key`: `1` (note that this key is now unused) + + You can see how many of these keys correspond to the various + fields in a "/sync" response: + ```json + { + "next_batch": "s12_4_0_1_1_1_1_4_1", + "presence": { + "events": [] + }, + "device_lists": { + "changed": [] + }, + "rooms": { + "join": { + "!QrZlfIDQLNLdZHqTnt:hs1": { + "timeline": { + "events": [], + "prev_batch": "s10_4_0_1_1_1_1_4_1", + "limited": false + }, + "state": { + "events": [] + }, + "account_data": { + "events": [] + }, + "ephemeral": { + "events": [] + } + } + } + } + } + ``` + + --- + + For caching purposes, `StreamToken`s and by extension, all their attributes, + must be hashable. + """ + + room_key: RoomStreamToken = attr.ib( + validator=attr.validators.instance_of(RoomStreamToken) + ) + presence_key: int + typing_key: int + receipt_key: int + account_data_key: int + push_rules_key: int + to_device_key: int + device_list_key: int + # Note that the groups key is no longer used and may have bogus values. + groups_key: int + + _SEPARATOR = "_" + START: ClassVar["StreamToken"] + + @classmethod + @cancellable + async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": + """ + Creates a RoomStreamToken from its textual representation. + """ + try: + keys = string.split(cls._SEPARATOR) + while len(keys) < len(attr.fields(cls)): + # i.e. old token from before receipt_key + keys.append("0") + return cls( + await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:]) + ) + except CancelledError: + raise + except Exception: + raise SynapseError(400, "Invalid stream token") + + async def to_string(self, store: "DataStore") -> str: + return self._SEPARATOR.join( + [ + await self.room_key.to_string(store), + str(self.presence_key), + str(self.typing_key), + str(self.receipt_key), + str(self.account_data_key), + str(self.push_rules_key), + str(self.to_device_key), + str(self.device_list_key), + # Note that the groups key is no longer used, but it is still + # serialized so that there will not be confusion in the future + # if additional tokens are added. + str(self.groups_key), + ] + ) + + @property + def room_stream_id(self) -> int: + return self.room_key.stream + + def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": + """Advance the given key in the token to a new value if and only if the + new value is after the old value. + + :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. + """ + if key == StreamKeyType.ROOM: + new_token = self.copy_and_replace( + StreamKeyType.ROOM, self.room_key.copy_and_advance(new_value) + ) + return new_token + + new_token = self.copy_and_replace(key, new_value) + new_id = int(getattr(new_token, key)) + old_id = int(getattr(self, key)) + + if old_id < new_id: + return new_token + else: + return self + + def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": + return attr.evolve(self, **{key: new_value}) + + +StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class PersistedEventPosition: + """Position of a newly persisted event with instance that persisted it. + + This can be used to test whether the event is persisted before or after a + RoomStreamToken. + """ + + instance_name: str + stream: int + + def persisted_after(self, token: RoomStreamToken) -> bool: + return token.get_stream_pos_for_instance(self.instance_name) < self.stream + + def to_room_stream_token(self) -> RoomStreamToken: + """Converts the position to a room stream token such that events + persisted in the same room after this position will be after the + returned `RoomStreamToken`. + + Note: no guarantees are made about ordering w.r.t. events in other + rooms. + """ + # Doing the naive thing satisfies the desired properties described in + # the docstring. + return RoomStreamToken(None, self.stream) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThirdPartyInstanceID: + appservice_id: Optional[str] + network_id: Optional[str] + + # Deny iteration because it will bite you if you try to create a singleton + # set by: + # users = set(user) + def __iter__(self) -> NoReturn: + raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) + + # Because this class is a frozen class, it is deeply immutable. + def __copy__(self) -> "ThirdPartyInstanceID": + return self + + def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": + return self + + @classmethod + def from_string(cls, s: str) -> "ThirdPartyInstanceID": + bits = s.split("|", 2) + if len(bits) != 2: + raise SynapseError(400, "Invalid ID %r" % (s,)) + + return cls(appservice_id=bits[0], network_id=bits[1]) + + def to_string(self) -> str: + return "%s|%s" % (self.appservice_id, self.network_id) + + __str__ = to_string + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ReadReceipt: + """Information about a read-receipt""" + + room_id: str + receipt_type: str + user_id: str + event_ids: List[str] + thread_id: Optional[str] + data: JsonDict + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class DeviceListUpdates: + """ + An object containing a diff of information regarding other users' device lists, intended for + a recipient to carry out device list tracking. + + Attributes: + changed: A set of users whose device lists have changed recently. + left: A set of users who the recipient no longer needs to track the device lists of. + Typically when those users no longer share any end-to-end encryption enabled rooms. + """ + + # We need to use a factory here, otherwise `set` is not evaluated at + # object instantiation, but instead at class definition instantiation. + # The latter happening only once, thus always giving you the same sets + # across multiple DeviceListUpdates instances. + # Also see: don't define mutable default arguments. + changed: Set[str] = attr.ib(factory=set) + left: Set[str] = attr.ib(factory=set) + + def __bool__(self) -> bool: + return bool(self.changed or self.left) + + +def get_verify_key_from_cross_signing_key( + key_info: Mapping[str, Any] +) -> Tuple[str, VerifyKey]: + """Get the key ID and signedjson verify key from a cross-signing key dict + + Args: + key_info: a cross-signing key dict, which must have a "keys" + property that has exactly one item in it + + Returns: + the key ID and verify key for the cross-signing key + """ + # make sure that a `keys` field is provided + if "keys" not in key_info: + raise ValueError("Invalid key") + keys = key_info["keys"] + # and that it contains exactly one key + if len(keys) == 1: + key_id, key_data = next(iter(keys.items())) + return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) + else: + raise ValueError("Invalid key") + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class UserInfo: + """Holds information about a user. Result of get_userinfo_by_id. + + Attributes: + user_id: ID of the user. + appservice_id: Application service ID that created this user. + consent_server_notice_sent: Version of policy documents the user has been sent. + consent_version: Version of policy documents the user has consented to. + creation_ts: Creation timestamp of the user. + is_admin: True if the user is an admin. + is_deactivated: True if the user has been deactivated. + is_guest: True if the user is a guest user. + is_shadow_banned: True if the user has been shadow-banned. + user_type: User type (None for normal user, 'support' and 'bot' other options). + """ + + user_id: UserID + appservice_id: Optional[int] + consent_server_notice_sent: Optional[str] + consent_version: Optional[str] + user_type: Optional[str] + creation_ts: int + is_admin: bool + is_deactivated: bool + is_guest: bool + is_shadow_banned: bool + + +class UserProfile(TypedDict): + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RetentionPolicy: + min_lifetime: Optional[int] = None + max_lifetime: Optional[int] = None diff --git a/synapse/types/state.py b/synapse/types/state.py new file mode 100644 index 0000000000..0004d955b4 --- /dev/null +++ b/synapse/types/state.py @@ -0,0 +1,567 @@ +# Copyright 2014-2016 OpenMarket Ltd +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import ( + TYPE_CHECKING, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + TypeVar, +) + +import attr +from frozendict import frozendict + +from synapse.api.constants import EventTypes +from synapse.types import MutableStateMap, StateKey, StateMap + +if TYPE_CHECKING: + from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad + + +logger = logging.getLogger(__name__) + +# Used for generic functions below +T = TypeVar("T") + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class StateFilter: + """A filter used when querying for state. + + Attributes: + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not + appear in `types`. + """ + + types: "frozendict[str, Optional[FrozenSet[str]]]" + include_others: bool = False + + def __attrs_post_init__(self) -> None: + # If `include_others` is set we canonicalise the filter by removing + # wildcards from the types dictionary + if self.include_others: + # this is needed to work around the fact that StateFilter is frozen + object.__setattr__( + self, + "types", + frozendict({k: v for k, v in self.types.items() if v is not None}), + ) + + @staticmethod + def all() -> "StateFilter": + """Returns a filter that fetches everything. + + Returns: + The state filter. + """ + return _ALL_STATE_FILTER + + @staticmethod + def none() -> "StateFilter": + """Returns a filter that fetches nothing. + + Returns: + The new state filter. + """ + return _NONE_STATE_FILTER + + @staticmethod + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": + """Creates a filter that only fetches the given types + + Args: + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type + + Returns: + The new state filter. + """ + type_dict: Dict[str, Optional[Set[str]]] = {} + for typ, s in types: + if typ in type_dict: + if type_dict[typ] is None: + continue + + if s is None: + type_dict[typ] = None + continue + + type_dict.setdefault(typ, set()).add(s) # type: ignore + + return StateFilter( + types=frozendict( + (k, frozenset(v) if v is not None else None) + for k, v in type_dict.items() + ) + ) + + @staticmethod + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": + """Creates a filter that returns all non-member events, plus the member + events for the given users + + Args: + members: Set of user IDs + + Returns: + The new state filter + """ + return StateFilter( + types=frozendict({EventTypes.Member: frozenset(members)}), + include_others=True, + ) + + @staticmethod + def freeze( + types: Mapping[str, Optional[Collection[str]]], include_others: bool + ) -> "StateFilter": + """ + Returns a (frozen) StateFilter with the same contents as the parameters + specified here, which can be made of mutable types. + """ + types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {} + for state_types, state_keys in types.items(): + if state_keys is not None: + types_with_frozen_values[state_types] = frozenset(state_keys) + else: + types_with_frozen_values[state_types] = None + + return StateFilter( + frozendict(types_with_frozen_values), include_others=include_others + ) + + def return_expanded(self) -> "StateFilter": + """Creates a new StateFilter where type wild cards have been removed + (except for memberships). The returned filter is a superset of the + current one, i.e. anything that passes the current filter will pass + the returned filter. + + This helps the caching as the DictionaryCache knows if it has *all* the + state, but does not know if it has all of the keys of a particular type, + which makes wildcard lookups expensive unless we have a complete cache. + Hence, if we are doing a wildcard lookup, populate the cache fully so + that we can do an efficient lookup next time. + + Note that since we have two caches, one for membership events and one for + other events, we can be a bit more clever than simply returning + `StateFilter.all()` if `has_wildcards()` is True. + + We return a StateFilter where: + 1. the list of membership events to return is the same + 2. if there is a wildcard that matches non-member events we + return all non-member events + + Returns: + The new state filter. + """ + + if self.is_full(): + # If we're going to return everything then there's nothing to do + return self + + if not self.has_wildcards(): + # If there are no wild cards, there's nothing to do + return self + + if EventTypes.Member in self.types: + get_all_members = self.types[EventTypes.Member] is None + else: + get_all_members = self.include_others + + has_non_member_wildcard = self.include_others or any( + state_keys is None + for t, state_keys in self.types.items() + if t != EventTypes.Member + ) + + if not has_non_member_wildcard: + # If there are no non-member wild cards we can just return ourselves + return self + + if get_all_members: + # We want to return everything. + return StateFilter.all() + elif EventTypes.Member in self.types: + # We want to return all non-members, but only particular + # memberships + return StateFilter( + types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}), + include_others=True, + ) + else: + # We want to return all non-members + return _ALL_NON_MEMBER_STATE_FILTER + + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: + """Converts the filter to an SQL clause. + + For example: + + f = StateFilter.from_types([("m.room.create", "")]) + clause, args = f.make_sql_filter_clause() + clause == "(type = ? AND state_key = ?)" + args == ['m.room.create', ''] + + + Returns: + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). + """ + + where_clause = "" + where_args: List[str] = [] + + if self.is_full(): + return where_clause, where_args + + if not self.include_others and not self.types: + # i.e. this is an empty filter, so we need to return a clause that + # will match nothing + return "1 = 2", [] + + # First we build up a lost of clauses for each type/state_key combo + clauses = [] + for etype, state_keys in self.types.items(): + if state_keys is None: + clauses.append("(type = ?)") + where_args.append(etype) + continue + + for state_key in state_keys: + clauses.append("(type = ? AND state_key = ?)") + where_args.extend((etype, state_key)) + + # This will match anything that appears in `self.types` + where_clause = " OR ".join(clauses) + + # If we want to include stuff that's not in the types dict then we add + # a `OR type NOT IN (...)` clause to the end. + if self.include_others: + if where_clause: + where_clause += " OR " + + where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),) + where_args.extend(self.types) + + return where_clause, where_args + + def max_entries_returned(self) -> Optional[int]: + """Returns the maximum number of entries this filter will return if + known, otherwise returns None. + + For example a simple state filter asking for `("m.room.create", "")` + will return 1, whereas the default state filter will return None. + + This is used to bail out early if the right number of entries have been + fetched. + """ + if self.has_wildcards(): + return None + + return len(self.concrete_types()) + + def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]: + """Returns the state filtered with by this StateFilter. + + Args: + state: The state map to filter + + Returns: + The filtered state map. + This is a copy, so it's safe to mutate. + """ + if self.is_full(): + return dict(state_dict) + + filtered_state = {} + for k, v in state_dict.items(): + typ, state_key = k + if typ in self.types: + state_keys = self.types[typ] + if state_keys is None or state_key in state_keys: + filtered_state[k] = v + elif self.include_others: + filtered_state[k] = v + + return filtered_state + + def is_full(self) -> bool: + """Whether this filter fetches everything or not + + Returns: + True if the filter fetches everything. + """ + return self.include_others and not self.types + + def has_wildcards(self) -> bool: + """Whether the filter includes wildcards or is attempting to fetch + specific state. + + Returns: + True if the filter includes wildcards. + """ + + return self.include_others or any( + state_keys is None for state_keys in self.types.values() + ) + + def concrete_types(self) -> List[Tuple[str, str]]: + """Returns a list of concrete type/state_keys (i.e. not None) that + will be fetched. This will be a complete list if `has_wildcards` + returns False, but otherwise will be a subset (or even empty). + + Returns: + A list of type/state_keys tuples. + """ + return [ + (t, s) + for t, state_keys in self.types.items() + if state_keys is not None + for s in state_keys + ] + + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: + """Return the filter split into two: one which assumes it's exclusively + matching against member state, and one which assumes it's matching + against non member state. + + This is useful due to the returned filters giving correct results for + `is_full()`, `has_wildcards()`, etc, when operating against maps that + either exclusively contain member events or only contain non-member + events. (Which is the case when dealing with the member vs non-member + state caches). + + Returns: + The member and non member filters + """ + + if EventTypes.Member in self.types: + state_keys = self.types[EventTypes.Member] + if state_keys is None: + member_filter = StateFilter.all() + else: + member_filter = StateFilter(frozendict({EventTypes.Member: state_keys})) + elif self.include_others: + member_filter = StateFilter.all() + else: + member_filter = StateFilter.none() + + non_member_filter = StateFilter( + types=frozendict( + {k: v for k, v in self.types.items() if k != EventTypes.Member} + ), + include_others=self.include_others, + ) + + return member_filter, non_member_filter + + def _decompose_into_four_parts( + self, + ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]: + """ + Decomposes this state filter into 4 constituent parts, which can be + thought of as this: + all? - minus_wildcards + plus_wildcards + plus_state_keys + + where + * all represents ALL state + * minus_wildcards represents entire state types to remove + * plus_wildcards represents entire state types to add + * plus_state_keys represents individual state keys to add + + See `recompose_from_four_parts` for the other direction of this + correspondence. + """ + is_all = self.include_others + excluded_types: Set[str] = {t for t in self.types if is_all} + wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None} + concrete_keys: Set[StateKey] = set(self.concrete_types()) + + return (is_all, excluded_types), (wildcard_types, concrete_keys) + + @staticmethod + def _recompose_from_four_parts( + all_part: bool, + minus_wildcards: Set[str], + plus_wildcards: Set[str], + plus_state_keys: Set[StateKey], + ) -> "StateFilter": + """ + Recomposes a state filter from 4 parts. + + See `decompose_into_four_parts` (the other direction of this + correspondence) for descriptions on each of the parts. + """ + + # {state type -> set of state keys OR None for wildcard} + # (The same structure as that of a StateFilter.) + new_types: Dict[str, Optional[Set[str]]] = {} + + # if we start with all, insert the excluded statetypes as empty sets + # to prevent them from being included + if all_part: + new_types.update({state_type: set() for state_type in minus_wildcards}) + + # insert the plus wildcards + new_types.update({state_type: None for state_type in plus_wildcards}) + + # insert the specific state keys + for state_type, state_key in plus_state_keys: + if state_type in new_types: + entry = new_types[state_type] + if entry is not None: + entry.add(state_key) + elif not all_part: + # don't insert if the entire type is already included by + # include_others as this would actually shrink the state allowed + # by this filter. + new_types[state_type] = {state_key} + + return StateFilter.freeze(new_types, include_others=all_part) + + def approx_difference(self, other: "StateFilter") -> "StateFilter": + """ + Returns a state filter which represents `self - other`. + + This is useful for determining what state remains to be pulled out of the + database if we want the state included by `self` but already have the state + included by `other`. + + The returned state filter + - MUST include all state events that are included by this filter (`self`) + unless they are included by `other`; + - MUST NOT include state events not included by this filter (`self`); and + - MAY be an over-approximation: the returned state filter + MAY additionally include some state events from `other`. + + This implementation attempts to return the narrowest such state filter. + In the case that `self` contains wildcards for state types where + `other` contains specific state keys, an approximation must be made: + the returned state filter keeps the wildcard, as state filters are not + able to express 'all state keys except some given examples'. + e.g. + StateFilter(m.room.member -> None (wildcard)) + minus + StateFilter(m.room.member -> {'@wombat:example.org'}) + is approximated as + StateFilter(m.room.member -> None (wildcard)) + """ + + # We first transform self and other into an alternative representation: + # - whether or not they include all events to begin with ('all') + # - if so, which event types are excluded? ('excludes') + # - which entire event types to include ('wildcards') + # - which concrete state keys to include ('concrete state keys') + (self_all, self_excludes), ( + self_wildcards, + self_concrete_keys, + ) = self._decompose_into_four_parts() + (other_all, other_excludes), ( + other_wildcards, + other_concrete_keys, + ) = other._decompose_into_four_parts() + + # Start with an estimate of the difference based on self + new_all = self_all + # Wildcards from the other can be added to the exclusion filter + new_excludes = self_excludes | other_wildcards + # We remove wildcards that appeared as wildcards in the other + new_wildcards = self_wildcards - other_wildcards + # We filter out the concrete state keys that appear in the other + # as wildcards or concrete state keys. + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in self_concrete_keys + if state_type not in other_wildcards + } - other_concrete_keys + + if other_all: + if self_all: + # If self starts with all, then we add as wildcards any + # types which appear in the other's exclusion filter (but + # aren't in the self exclusion filter). This is as the other + # filter will return everything BUT the types in its exclusion, so + # we need to add those excluded types that also match the self + # filter as wildcard types in the new filter. + new_wildcards |= other_excludes.difference(self_excludes) + + # If other is an `include_others` then the difference isn't. + new_all = False + # (We have no need for excludes when we don't start with all, as there + # is nothing to exclude.) + new_excludes = set() + + # We also filter out all state types that aren't in the exclusion + # list of the other. + new_wildcards &= other_excludes + new_concrete_keys = { + (state_type, state_key) + for (state_type, state_key) in new_concrete_keys + if state_type in other_excludes + } + + # Transform our newly-constructed state filter from the alternative + # representation back into the normal StateFilter representation. + return StateFilter._recompose_from_four_parts( + new_all, new_excludes, new_wildcards, new_concrete_keys + ) + + def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool: + """Check if we need to wait for full state to complete to calculate this state + + If we have a state filter which is completely satisfied even with partial + state, then we don't need to await_full_state before we can return it. + + Args: + is_mine_id: a callable which confirms if a given state_key matches a mxid + of a local user + """ + # if we haven't requested membership events, then it depends on the value of + # 'include_others' + if EventTypes.Member not in self.types: + return self.include_others + + # if we're looking for *all* membership events, then we have to wait + member_state_keys = self.types[EventTypes.Member] + if member_state_keys is None: + return True + + # otherwise, consider whose membership we are looking for. If it's entirely + # local users, then we don't need to wait. + for state_key in member_state_keys: + if not is_mine_id(state_key): + # remote user + return True + + # local users only + return False + + +_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True) +_ALL_NON_MEMBER_STATE_FILTER = StateFilter( + types=frozendict({EventTypes.Member: frozenset()}), include_others=True +) +_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) diff --git a/synapse/visibility.py b/synapse/visibility.py index b443857571..e442de3173 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -26,8 +26,8 @@ from synapse.events.utils import prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore -from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import Clock logger = logging.getLogger(__name__) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index d4e6d4236c..a433e70870 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -22,8 +22,8 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.server import HomeServer -from synapse.storage.state import StateFilter from synapse.types import JsonDict, RoomID, StateMap, UserID +from synapse.types.state import StateFilter from synapse.util import Clock from tests.unittest import HomeserverTestCase, TestCase -- cgit 1.4.1