diff options
Diffstat (limited to 'synapse')
62 files changed, 1004 insertions, 556 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 3d7f986ac7..66e869bc2d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -32,7 +32,6 @@ from synapse.appservice import ApplicationService from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging.opentracing import ( - SynapseTags, active_span, force_tracing, start_active_span, @@ -162,12 +161,6 @@ class Auth: parent_span.set_tag( "authenticated_entity", requester.authenticated_entity ) - # We tag the Synapse instance name so that it's an easy jumping - # off point into the logs. Can also be used to filter for an - # instance that is under load. - parent_span.set_tag( - SynapseTags.INSTANCE_NAME, self.hs.get_instance_name() - ) parent_span.set_tag("user_id", requester.user.to_string()) if requester.device_id is not None: parent_span.set_tag("device_id", requester.device_id) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index c2c177fd71..9235ce6536 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -751,3 +751,25 @@ class ModuleFailedException(Exception): Raised when a module API callback fails, for example because it raised an exception. """ + + +class PartialStateConflictError(SynapseError): + """An internal error raised when attempting to persist an event with partial state + after the room containing the event has been un-partial stated. + + This error should be handled by recomputing the event context and trying again. + + This error has an HTTP status code so that it can be transported over replication. + It should not be exposed to clients. + """ + + @staticmethod + def message() -> str: + return "Cannot persist partial state event in un-partial stated room" + + def __init__(self) -> None: + super().__init__( + HTTPStatus.CONFLICT, + msg=PartialStateConflictError.message(), + errcode=Codes.UNKNOWN, + ) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 83c42fc25a..b9f432cc23 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -219,9 +219,13 @@ class FilterCollection: self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {})) self._room_state_filter = Filter(hs, room_filter_json.get("state", {})) self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {})) - self._room_account_data = Filter(hs, room_filter_json.get("account_data", {})) + self._room_account_data_filter = Filter( + hs, room_filter_json.get("account_data", {}) + ) self._presence_filter = Filter(hs, filter_json.get("presence", {})) - self._account_data = Filter(hs, filter_json.get("account_data", {})) + self._global_account_data_filter = Filter( + hs, filter_json.get("account_data", {}) + ) self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.event_fields = filter_json.get("event_fields", []) @@ -256,8 +260,10 @@ class FilterCollection: ) -> List[UserPresenceState]: return await self._presence_filter.filter(presence_states) - async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: - return await self._account_data.filter(events) + async def filter_global_account_data( + self, events: Iterable[JsonDict] + ) -> List[JsonDict]: + return await self._global_account_data_filter.filter(events) async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: return await self._room_state_filter.filter( @@ -279,7 +285,7 @@ class FilterCollection: async def filter_room_account_data( self, events: Iterable[JsonDict] ) -> List[JsonDict]: - return await self._room_account_data.filter( + return await self._room_account_data_filter.filter( await self._room_filter.filter(events) ) @@ -292,6 +298,13 @@ class FilterCollection: or self._presence_filter.filters_all_senders() ) + def blocks_all_global_account_data(self) -> bool: + """True if all global acount data will be filtered out.""" + return ( + self._global_account_data_filter.filters_all_types() + or self._global_account_data_filter.filters_all_senders() + ) + def blocks_all_room_ephemeral(self) -> bool: return ( self._room_ephemeral_filter.filters_all_types() @@ -299,6 +312,13 @@ class FilterCollection: or self._room_ephemeral_filter.filters_all_rooms() ) + def blocks_all_room_account_data(self) -> bool: + return ( + self._room_account_data_filter.filters_all_types() + or self._room_account_data_filter.filters_all_senders() + or self._room_account_data_filter.filters_all_rooms() + ) + def blocks_all_room_timeline(self) -> bool: return ( self._room_timeline_filter.filters_all_types() diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 53db1e85b3..897dd3edac 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -15,7 +15,7 @@ import logging import math import resource import sys -from typing import TYPE_CHECKING, List, Sized, Tuple +from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple from prometheus_client import Gauge @@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None: @wrap_as_background_process("generate_monthly_active_users") async def generate_monthly_active_users() -> None: current_mau_count = 0 - current_mau_count_by_service = {} + current_mau_count_by_service: Mapping[str, int] = {} reserved_users: Sized = () store = hs.get_datastores().main if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 53c0682dfd..6ac2f0c10d 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -169,6 +169,16 @@ class ExperimentalConfig(Config): # MSC3925: do not replace events with their edits self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False) + # MSC3758: exact_event_match push rule condition + self.msc3758_exact_event_match = experimental.get( + "msc3758_exact_event_match", False + ) + + # MSC3873: Disambiguate event_match keys. + self.msc3783_escape_event_match_key = experimental.get( + "msc3783_escape_event_match_key", False + ) + # MSC3952: Intentional mentions self.msc3952_intentional_mentions = experimental.get( "msc3952_intentional_mentions", False diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 3ed236217f..8666c22f01 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Collection from matrix_common.regex import glob_to_regex @@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config): return False def is_publishing_room_allowed( - self, user_id: str, room_id: str, aliases: List[str] + self, user_id: str, room_id: str, aliases: Collection[str] ) -> bool: """Checks if the given user is allowed to publish the room @@ -122,7 +122,7 @@ class _RoomDirectoryRule: except Exception as e: raise ConfigError("Failed to parse glob into regex") from e - def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool: + def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool: """Tests if this rule matches the given user_id, room_id and aliases. Args: diff --git a/synapse/event_auth.py b/synapse/event_auth.py index e0be9f88cc..4d6d1b8ebd 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -16,18 +16,7 @@ import collections.abc import logging import typing -from typing import ( - Any, - Collection, - Dict, - Iterable, - List, - Mapping, - Optional, - Set, - Tuple, - Union, -) +from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -56,7 +45,13 @@ from synapse.api.room_versions import ( RoomVersions, ) from synapse.storage.databases.main.events_worker import EventRedactBehaviour -from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id +from synapse.types import ( + MutableStateMap, + StateMap, + StrCollection, + UserID, + get_domain_from_id, +) if typing.TYPE_CHECKING: # conditional imports to avoid import cycle @@ -69,7 +64,7 @@ logger = logging.getLogger(__name__) class _EventSourceStore(Protocol): async def get_events( self, - event_ids: Collection[str], + event_ids: StrCollection, redact_behaviour: EventRedactBehaviour, get_prev_content: bool = False, allow_rejected: bool = False, diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8aca9a3ab9..91118a8d84 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -39,7 +39,7 @@ from unpaddedbase64 import encode_base64 from synapse.api.constants import RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.types import JsonDict, RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken, StrCollection from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze from synapse.util.stringutils import strtobool @@ -413,7 +413,7 @@ class EventBase(metaclass=abc.ABCMeta): """ return [e for e, _ in self._dict["prev_events"]] - def auth_event_ids(self) -> Sequence[str]: + def auth_event_ids(self) -> StrCollection: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. @@ -558,7 +558,7 @@ class FrozenEventV2(EventBase): """ return self._dict["prev_events"] - def auth_event_ids(self) -> Sequence[str]: + def auth_event_ids(self) -> StrCollection: """Returns the list of auth event IDs. The order matches the order specified in the event, though there is no meaning to it. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 94dd1298e1..c82745275f 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey @@ -103,7 +103,7 @@ class EventBuilder: async def build( self, - prev_event_ids: List[str], + prev_event_ids: Collection[str], auth_event_ids: Optional[List[str]], depth: Optional[int] = None, ) -> EventBase: @@ -136,7 +136,7 @@ class EventBuilder: format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. - prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]] auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]] if format_version == EventFormatVersions.ROOM_V1_V2: auth_events = await self._store.add_event_hashes(auth_event_ids) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6eaef8b57a..e0d82ad81c 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, List, Optional, Tuple import attr @@ -26,8 +27,51 @@ if TYPE_CHECKING: from synapse.types.state import StateFilter +class UnpersistedEventContextBase(ABC): + """ + This is a base class for EventContext and UnpersistedEventContext, objects which + hold information relevant to storing an associated event. Note that an + UnpersistedEventContexts must be converted into an EventContext before it is + suitable to send to the db with its associated event. + + Attributes: + _storage: storage controllers for interfacing with the database + app_service: If the associated event is being sent by a (local) application service, that + app service. + """ + + def __init__(self, storage_controller: "StorageControllers"): + self._storage: "StorageControllers" = storage_controller + self.app_service: Optional[ApplicationService] = None + + @abstractmethod + async def persist( + self, + event: EventBase, + ) -> "EventContext": + """ + A method to convert an UnpersistedEventContext to an EventContext, suitable for + sending to the database with the associated event. + """ + pass + + @abstractmethod + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: + """ + Gets the room state at the event (ie not including the event if the event is a + state event). + + Args: + state_filter: specifies the type of state event to fetch from DB, example: + EventTypes.JoinRules + """ + pass + + @attr.s(slots=True, auto_attribs=True) -class EventContext: +class EventContext(UnpersistedEventContextBase): """ Holds information relevant to persisting an event @@ -77,9 +121,6 @@ class EventContext: delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` and ``state_group``. - app_service: If this event is being sent by a (local) application service, that - app service. - partial_state: if True, we may be storing this event with a temporary, incomplete state. """ @@ -122,6 +163,9 @@ class EventContext: """Return an EventContext instance suitable for persisting an outlier event""" return EventContext(storage=storage) + async def persist(self, event: EventBase) -> "EventContext": + return self + async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -254,6 +298,128 @@ class EventContext: ) +@attr.s(slots=True, auto_attribs=True) +class UnpersistedEventContext(UnpersistedEventContextBase): + """ + The event context holds information about the state groups for an event. It is important + to remember that an event technically has two state groups: the state group before the + event, and the state group after the event. If the event is not a state event, the state + group will not change (ie the state group before the event will be the same as the state + group after the event), but if it is a state event the state group before the event + will differ from the state group after the event. + This is a version of an EventContext before the new state group (if any) has been + computed and stored. It contains information about the state before the event (which + also may be the information after the event, if the event is not a state event). The + UnpersistedEventContext must be converted into an EventContext by calling the method + 'persist' on it before it is suitable to be sent to the DB for processing. + + state_group_after_event: + The state group after the event. This will always be None until it is persisted. + If the event is not a state event, this will be the same as + state_group_before_event. + + state_group_before_event: + The ID of the state group representing the state of the room before this event. + + state_delta_due_to_event: + If the event is a state event, then this is the delta of the state between + `state_group` and `state_group_before_event` + + prev_group_for_state_group_before_event: + If it is known, ``state_group_before_event``'s previous state group. + + delta_ids_to_state_group_before_event: + If ``prev_group_for_state_group_before_event`` is not None, the state delta + between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``. + + partial_state: + Whether the event has partial state. + + state_map_before_event: + A map of the state before the event, i.e. the state at `state_group_before_event` + """ + + _storage: "StorageControllers" + state_group_before_event: Optional[int] + state_group_after_event: Optional[int] + state_delta_due_to_event: Optional[dict] + prev_group_for_state_group_before_event: Optional[int] + delta_ids_to_state_group_before_event: Optional[StateMap[str]] + partial_state: bool + state_map_before_event: Optional[StateMap[str]] = None + + async def get_prev_state_ids( + self, state_filter: Optional["StateFilter"] = None + ) -> StateMap[str]: + """ + Gets the room state map, excluding this event. + + Args: + state_filter: specifies the type of state event to fetch from DB + + Returns: + Maps a (type, state_key) to the event ID of the state event matching + this tuple. + """ + if self.state_map_before_event: + return self.state_map_before_event + + assert self.state_group_before_event is not None + return await self._storage.state.get_state_ids_for_group( + self.state_group_before_event, state_filter + ) + + async def persist(self, event: EventBase) -> EventContext: + """ + Creates a full `EventContext` for the event, persisting any referenced state that + has not yet been persisted. + + Args: + event: event that the EventContext is associated with. + + Returns: An EventContext suitable for sending to the database with the event + for persisting + """ + assert self.partial_state is not None + + # If we have a full set of state for before the event but don't have a state + # group for that state, we need to get one + if self.state_group_before_event is None: + assert self.state_map_before_event + state_group_before_event = await self._storage.state.store_state_group( + event.event_id, + event.room_id, + prev_group=self.prev_group_for_state_group_before_event, + delta_ids=self.delta_ids_to_state_group_before_event, + current_state_ids=self.state_map_before_event, + ) + self.state_group_before_event = state_group_before_event + + # if the event isn't a state event the state group doesn't change + if not self.state_delta_due_to_event: + state_group_after_event = self.state_group_before_event + + # otherwise if it is a state event we need to get a state group for it + else: + state_group_after_event = await self._storage.state.store_state_group( + event.event_id, + event.room_id, + prev_group=self.state_group_before_event, + delta_ids=self.state_delta_due_to_event, + current_state_ids=None, + ) + + return EventContext.with_state( + storage=self._storage, + state_group=state_group_after_event, + state_group_before_event=self.state_group_before_event, + state_delta_due_to_event=self.state_delta_due_to_event, + partial_state=self.partial_state, + prev_group=self.state_group_before_event, + delta_ids=self.state_delta_due_to_event, + ) + + def _encode_state_dict( state_dict: Optional[StateMap[str]], ) -> Optional[List[Tuple[str, str, str]]]: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 72ab696898..97c61cc258 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError from synapse.api.errors import ModuleFailedException, SynapseError from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import UnpersistedEventContextBase from synapse.storage.roommember import ProfileInfo from synapse.types import Requester, StateMap from synapse.util.async_helpers import delay_cancellation, maybe_awaitable @@ -231,7 +231,9 @@ class ThirdPartyEventRules: self._on_threepid_bind_callbacks.append(on_threepid_bind) async def check_event_allowed( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: UnpersistedEventContextBase, ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8d36172484..6d99845de5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -23,6 +23,7 @@ from typing import ( Collection, Dict, List, + Mapping, Optional, Tuple, Union, @@ -47,6 +48,7 @@ from synapse.api.errors import ( FederationError, IncompatibleRoomVersionError, NotFoundError, + PartialStateConflictError, SynapseError, UnsupportedRoomVersionError, ) @@ -80,7 +82,6 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary @@ -1512,7 +1513,7 @@ class FederationHandlerRegistry: def _get_event_ids_for_partial_state_join( join_event: EventBase, prev_state_ids: StateMap[str], - summary: Dict[str, MemberSummary], + summary: Mapping[str, MemberSummary], ) -> Collection[str]: """Calculate state to be returned in a partial_state send_join diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 67e789eef7..797de46dbc 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -343,10 +343,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]): } ) - ( - account_data, - room_account_data, - ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) + account_data = await self.store.get_updated_global_account_data_for_user( + user_id, last_stream_id + ) + room_account_data = await self.store.get_updated_room_account_data_for_user( + user_id, last_stream_id + ) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 30f2d46c3c..57a6854b1e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1593,9 +1593,8 @@ class AuthHandler: if medium == "email": address = canonicalise_email(address) - identity_handler = self.hs.get_identity_handler() - result = await identity_handler.try_unbind_threepid( - user_id, {"medium": medium, "address": address, "id_server": id_server} + result = await self.hs.get_identity_handler().try_unbind_threepid( + user_id, medium, address, id_server ) await self.store.user_delete_threepid(user_id, medium, address) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index d74d135c0c..d24f649382 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -106,12 +106,7 @@ class DeactivateAccountHandler: for threepid in threepids: try: result = await self._identity_handler.try_unbind_threepid( - user_id, - { - "medium": threepid["medium"], - "address": threepid["address"], - "id_server": id_server, - }, + user_id, threepid["medium"], threepid["address"], id_server ) identity_server_supports_unbinding &= result except Exception: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 2ea52257cb..a5798e9483 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -14,7 +14,7 @@ import logging import string -from typing import TYPE_CHECKING, Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence from typing_extensions import Literal @@ -485,7 +485,8 @@ class DirectoryHandler: ) ) if canonical_alias: - room_aliases.append(canonical_alias) + # Ensure we do not mutate room_aliases. + room_aliases = list(room_aliases) + [canonical_alias] if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases @@ -528,7 +529,7 @@ class DirectoryHandler: async def get_aliases_for_room( self, requester: Requester, room_id: str - ) -> List[str]: + ) -> Sequence[str]: """ Get a list of the aliases that currently point to this room on this server """ diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d2188ca08f..43cbece21b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -159,19 +159,22 @@ class E2eKeysHandler: # A map of destination -> user ID -> device IDs. remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {} if remote_queries: - query_list: List[Tuple[str, Optional[str]]] = [] + user_ids = set() + user_and_device_ids: List[Tuple[str, str]] = [] for user_id, device_ids in remote_queries.items(): if device_ids: - query_list.extend( + user_and_device_ids.extend( (user_id, device_id) for device_id in device_ids ) else: - query_list.append((user_id, None)) + user_ids.add(user_id) ( user_ids_not_in_cache, remote_results, - ) = await self.store.get_user_devices_from_cache(query_list) + ) = await self.store.get_user_devices_from_cache( + user_ids, user_and_device_ids + ) # Check that the homeserver still shares a room with all cached users. # Note that this check may be slightly racy when a remote user leaves a diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index a23a8ce2a1..46dd63c3f0 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -202,7 +202,7 @@ class EventAuthHandler: state_ids: StateMap[str], room_version: RoomVersion, user_id: str, - prev_member_event: Optional[EventBase], + prev_membership: Optional[str], ) -> None: """ Check whether a user can join a room without an invite due to restricted join rules. @@ -214,15 +214,14 @@ class EventAuthHandler: state_ids: The state of the room as it currently is. room_version: The room version of the room being joined. user_id: The user joining the room. - prev_member_event: The current membership event for this user. + prev_membership: The current membership state for this user. `None` if the + user has never joined the room (equivalent to "leave"). Raises: AuthError if the user cannot join the room. """ # If the member is invited or currently joined, then nothing to do. - if prev_member_event and ( - prev_member_event.membership in (Membership.JOIN, Membership.INVITE) - ): + if prev_membership in (Membership.JOIN, Membership.INVITE): return # This is not a room with a restricted join rule, so we don't need to do the @@ -255,13 +254,14 @@ class EventAuthHandler: ) async def has_restricted_join_rules( - self, state_ids: StateMap[str], room_version: RoomVersion + self, partial_state_ids: StateMap[str], room_version: RoomVersion ) -> bool: """ Return if the room has the proper join rules set for access via rooms. Args: - state_ids: The state of the room as it currently is. + state_ids: The state of the room as it currently is. May be full or partial + state. room_version: The room version of the room to query. Returns: @@ -272,7 +272,7 @@ class EventAuthHandler: return False # If there's no join rule, then it defaults to invite (so this doesn't apply). - join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None) + join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None) if not join_rules_event_id: return False diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7f64130e0a..08727e4857 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -49,6 +49,7 @@ from synapse.api.errors import ( FederationPullAttemptBackoffError, HttpResponseException, NotFoundError, + PartialStateConflictError, RequestSendFailed, SynapseError, ) @@ -56,7 +57,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.crypto.event_signing import compute_event_signature from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError from synapse.http.servlet import assert_params_in_dict @@ -68,7 +69,6 @@ from synapse.replication.http.federation import ( ReplicationCleanRoomRestServlet, ReplicationStoreRoomOnOutlierMembershipRestServlet, ) -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter @@ -990,7 +990,10 @@ class FederationHandler: ) try: - event, context = await self.event_creation_handler.create_new_client_event( + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event( builder=builder ) except SynapseError as e: @@ -998,7 +1001,9 @@ class FederationHandler: raise # Ensure the user can even join the room. - await self._federation_event_handler.check_join_restrictions(context, event) + await self._federation_event_handler.check_join_restrictions( + unpersisted_context, event + ) # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` @@ -1178,7 +1183,7 @@ class FederationHandler: }, ) - event, context = await self.event_creation_handler.create_new_client_event( + event, _ = await self.event_creation_handler.create_new_client_event( builder=builder ) @@ -1228,12 +1233,13 @@ class FederationHandler: }, ) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event(builder=builder) event_allowed, _ = await self.third_party_event_rules.check_event_allowed( - event, context + event, unpersisted_context ) if not event_allowed: logger.warning("Creation of knock %s forbidden by third-party rules", event) @@ -1406,15 +1412,20 @@ class FederationHandler: try: ( event, - context, + unpersisted_context, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context + ( + event, + unpersisted_context, + ) = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, unpersisted_context ) + context = await unpersisted_context.persist(event) + EventValidator().validate_new(event, self.config) # We need to tell the transaction queue to send this out, even @@ -1483,14 +1494,19 @@ class FederationHandler: try: ( event, - context, + unpersisted_context, ) = await self.event_creation_handler.create_new_client_event( builder=builder ) - event, context = await self.add_display_name_to_third_party_invite( - room_version_obj, event_dict, event, context + ( + event, + unpersisted_context, + ) = await self.add_display_name_to_third_party_invite( + room_version_obj, event_dict, event, unpersisted_context ) + context = await unpersisted_context.persist(event) + try: validate_event_for_room_version(event) await self._event_auth_handler.check_auth_rules_from_context(event) @@ -1522,8 +1538,8 @@ class FederationHandler: room_version_obj: RoomVersion, event_dict: JsonDict, event: EventBase, - context: EventContext, - ) -> Tuple[EventBase, EventContext]: + context: UnpersistedEventContextBase, + ) -> Tuple[EventBase, UnpersistedEventContextBase]: key = ( EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"], @@ -1557,11 +1573,14 @@ class FederationHandler: room_version_obj, event_dict ) EventValidator().validate_builder(builder) - event, context = await self.event_creation_handler.create_new_client_event( - builder=builder - ) + + ( + event, + unpersisted_context, + ) = await self.event_creation_handler.create_new_client_event(builder=builder) + EventValidator().validate_new(event, self.config) - return event, context + return event, unpersisted_context async def _check_signature(self, event: EventBase, context: EventContext) -> None: """ diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index e037acbca2..b7136f8d1c 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -47,6 +47,7 @@ from synapse.api.errors import ( FederationError, FederationPullAttemptBackoffError, HttpResponseException, + PartialStateConflictError, RequestSendFailed, SynapseError, ) @@ -58,7 +59,7 @@ from synapse.event_auth import ( validate_event_for_room_version, ) from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import ( @@ -74,7 +75,6 @@ from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, ) from synapse.state import StateResolutionStore -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( PersistedEventPosition, @@ -426,7 +426,9 @@ class FederationEventHandler: return event, context async def check_join_restrictions( - self, context: EventContext, event: EventBase + self, + context: UnpersistedEventContextBase, + event: EventBase, ) -> None: """Check that restrictions in restricted join rules are matched @@ -439,16 +441,17 @@ class FederationEventHandler: # Check if the user is already in the room or invited to the room. user_id = event.state_key prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) - prev_member_event = None + prev_membership = None if prev_member_event_id: prev_member_event = await self._store.get_event(prev_member_event_id) + prev_membership = prev_member_event.membership # Check if the member should be allowed access via membership in a space. await self._event_auth_handler.check_restricted_join_rules( prev_state_ids, event.room_version, user_id, - prev_member_event, + prev_membership, ) @trace @@ -524,11 +527,57 @@ class FederationEventHandler: "Peristing join-via-remote %s (partial_state: %s)", event, partial_state ) with nested_logging_context(suffix=event.event_id): + if partial_state: + # When handling a second partial state join into a partial state room, + # the returned state will exclude the membership from the first join. To + # preserve prior memberships, we try to compute the partial state before + # the event ourselves if we know about any of the prev events. + # + # When we don't know about any of the prev events, it's fine to just use + # the returned state, since the new join will create a new forward + # extremity, and leave the forward extremity containing our prior + # memberships alone. + prev_event_ids = set(event.prev_event_ids()) + seen_event_ids = await self._store.have_events_in_timeline( + prev_event_ids + ) + missing_event_ids = prev_event_ids - seen_event_ids + + state_maps_to_resolve: List[StateMap[str]] = [] + + # Fetch the state after the prev events that we know about. + state_maps_to_resolve.extend( + ( + await self._state_storage_controller.get_state_groups_ids( + room_id, seen_event_ids, await_full_state=False + ) + ).values() + ) + + # When there are prev events we do not have the state for, we state + # resolve with the state returned by the remote homeserver. + if missing_event_ids or len(state_maps_to_resolve) == 0: + state_maps_to_resolve.append( + {(e.type, e.state_key): e.event_id for e in state} + ) + + state_ids_before_event = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version.identifier, + state_maps_to_resolve, + event_map=None, + state_res_store=StateResolutionStore(self._store), + ) + ) + else: + state_ids_before_event = { + (e.type, e.state_key): e.event_id for e in state + } + context = await self._state_handler.compute_event_context( event, - state_ids_before_event={ - (e.type, e.state_key): e.event_id for e in state - }, + state_ids_before_event=state_ids_before_event, partial_state=partial_state, ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 848e46eb9b..bf0f7acf80 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -219,28 +219,31 @@ class IdentityHandler: data = json_decoder.decode(e.msg) # XXX WAT? return data - async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool: - """Attempt to remove a 3PID from an identity server, or if one is not provided, all - identity servers we're aware the binding is present on + async def try_unbind_threepid( + self, mxid: str, medium: str, address: str, id_server: Optional[str] + ) -> bool: + """Attempt to remove a 3PID from one or more identity servers. Args: mxid: Matrix user ID of binding to be removed - threepid: Dict with medium & address of binding to be - removed, and an optional id_server. + medium: The medium of the third-party ID. + address: The address of the third-party ID. + id_server: An identity server to attempt to unbind from. If None, + attempt to remove the association from all identity servers + known to potentially have it. Raises: - SynapseError: If we failed to contact the identity server + SynapseError: If we failed to contact one or more identity servers. Returns: - True on success, otherwise False if the identity - server doesn't support unbinding (or no identity server found to - contact). + True on success, otherwise False if the identity server doesn't + support unbinding (or no identity server to contact was found). """ - if threepid.get("id_server"): - id_servers = [threepid["id_server"]] + if id_server: + id_servers = [id_server] else: id_servers = await self.store.get_id_servers_user_bound( - user_id=mxid, medium=threepid["medium"], address=threepid["address"] + mxid, medium, address ) # We don't know where to unbind, so we don't have a choice but to return @@ -249,20 +252,21 @@ class IdentityHandler: changed = True for id_server in id_servers: - changed &= await self.try_unbind_threepid_with_id_server( - mxid, threepid, id_server + changed &= await self._try_unbind_threepid_with_id_server( + mxid, medium, address, id_server ) return changed - async def try_unbind_threepid_with_id_server( - self, mxid: str, threepid: dict, id_server: str + async def _try_unbind_threepid_with_id_server( + self, mxid: str, medium: str, address: str, id_server: str ) -> bool: """Removes a binding from an identity server Args: mxid: Matrix user ID of binding to be removed - threepid: Dict with medium & address of binding to be removed + medium: The medium of the third-party ID + address: The address of the third-party ID id_server: Identity server to unbind from Raises: @@ -286,7 +290,7 @@ class IdentityHandler: content = { "mxid": mxid, - "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, + "threepid": {"medium": medium, "address": address}, } # we abuse the federation http client to sign the request, but we have to send it @@ -319,12 +323,7 @@ class IdentityHandler: except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - await self.store.remove_user_bound_threepid( - user_id=mxid, - medium=threepid["medium"], - address=threepid["address"], - id_server=id_server, - ) + await self.store.remove_user_bound_threepid(mxid, medium, address, id_server) return changed diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 191529bd8e..1a29abde98 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -154,9 +154,8 @@ class InitialSyncHandler: tags_by_room = await self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = await self.store.get_account_data_for_user( - user_id - ) + account_data = await self.store.get_global_account_data_for_user(user_id) + account_data_by_room = await self.store.get_room_account_data_for_user(user_id) public_room_ids = await self.store.get_public_room_ids() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e688e00575..8f5b658d9d 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -38,6 +38,7 @@ from synapse.api.errors import ( Codes, ConsentNotGivenError, NotFoundError, + PartialStateConflictError, ShadowBanError, SynapseError, UnstableSpecAuthError, @@ -48,7 +49,7 @@ from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase, relation_from_event from synapse.events.builder import EventBuilder -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.utils import maybe_upsert_event_field from synapse.events.validator import EventValidator from synapse.handlers.directory import DirectoryHandler @@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( MutableStateMap, @@ -499,9 +499,9 @@ class EventCreationHandler: self.request_ratelimiter = hs.get_request_ratelimiter() - # We arbitrarily limit concurrent event creation for a room to 5. - # This is to stop us from diverging history *too* much. - self.limiter = Linearizer(max_count=5, name="room_event_creation_limit") + # We limit concurrent event creation for a room to 1. This prevents state resolution + # from occurring when sending bursts of events to a local room + self.limiter = Linearizer(max_count=1, name="room_event_creation_limit") self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator() @@ -708,7 +708,7 @@ class EventCreationHandler: builder.internal_metadata.historical = historical - event, context = await self.create_new_client_event( + event, unpersisted_context = await self.create_new_client_event( builder=builder, requester=requester, allow_no_prev_events=allow_no_prev_events, @@ -721,6 +721,8 @@ class EventCreationHandler: current_state_group=current_state_group, ) + context = await unpersisted_context.persist(event) + # In an ideal world we wouldn't need the second part of this condition. However, # this behaviour isn't spec'd yet, meaning we should be able to deactivate this # behaviour. Another reason is that this code is also evaluated each time a new @@ -1083,13 +1085,14 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for the event using the parameters state_map and current_state_group, thus these parameters must be provided in this case if for_batch is True. The subsequently created event and context are suitable for being batched up and bulk persisted to the database - with other similarly created events. + with other similarly created events. Note that this returns an UnpersistedEventContext, + which must be converted to an EventContext before it can be sent to the DB. Args: builder: @@ -1131,7 +1134,7 @@ class EventCreationHandler: batch persisting Returns: - Tuple of created event, context + Tuple of created event, UnpersistedEventContext """ # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender @@ -1192,9 +1195,16 @@ class EventCreationHandler: event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth ) - context = await self.state.compute_event_context_for_batched( - event, state_map, current_state_group + + context: UnpersistedEventContextBase = ( + await self.state.calculate_context_info( + event, + state_ids_before_event=state_map, + partial_state=False, + state_group_before_event=current_state_group, + ) ) + else: event = await builder.build( prev_event_ids=prev_event_ids, @@ -1244,16 +1254,17 @@ class EventCreationHandler: state_map_for_event[(data.event_type, data.state_key)] = state_id - context = await self.state.compute_event_context( + # TODO(faster_joins): check how MSC2716 works and whether we can have + # partial state here + # https://github.com/matrix-org/synapse/issues/13003 + context = await self.state.calculate_context_info( event, state_ids_before_event=state_map_for_event, - # TODO(faster_joins): check how MSC2716 works and whether we can have - # partial state here - # https://github.com/matrix-org/synapse/issues/13003 partial_state=False, ) + else: - context = await self.state.compute_event_context(event) + context = await self.state.calculate_context_info(event) if requester: context.app_service = requester.app_service @@ -2082,9 +2093,9 @@ class EventCreationHandler: async def _rebuild_event_after_third_party_rules( self, third_party_result: dict, original_event: EventBase - ) -> Tuple[EventBase, EventContext]: + ) -> Tuple[EventBase, UnpersistedEventContextBase]: # the third_party_event_rules want to replace the event. - # we do some basic checks, and then return the replacement event and context. + # we do some basic checks, and then return the replacement event. # Construct a new EventBuilder and validate it, which helps with the # rest of these checks. @@ -2138,5 +2149,6 @@ class EventCreationHandler: # we rebuild the event context, to be on the safe side. If nothing else, # delta_ids might need an update. - context = await self.state.compute_event_context(event) + context = await self.state.calculate_context_info(event) + return event, context diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 04c61ae3dd..2bacdebfb5 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple from synapse.api.constants import EduTypes, ReceiptTypes from synapse.appservice import ApplicationService @@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): @staticmethod def filter_out_private_receipts( - rooms: List[JsonDict], user_id: str + rooms: Sequence[JsonDict], user_id: str ) -> List[JsonDict]: """ Filters a list of serialized receipts (as returned by /sync and /initialSync) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7ba7c4ff07..837dabb3b7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -43,6 +43,7 @@ from synapse.api.errors import ( Codes, LimitExceededError, NotFoundError, + PartialStateConflictError, StoreError, SynapseError, ) @@ -54,7 +55,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations from synapse.module_api import NOT_SPAM from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.streams import EventSource from synapse.types import ( JsonDict, @@ -1076,7 +1076,7 @@ class RoomCreationHandler: state_map: MutableStateMap[str] = {} # current_state_group of last event created. Used for computing event context of # events to be batched - current_state_group = None + current_state_group: Optional[int] = None def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict: e = {"type": etype, "content": content} @@ -1928,6 +1928,6 @@ class RoomShutdownHandler: return { "kicked_users": kicked_users, "failed_to_kick_users": failed_to_kick_users, - "local_aliases": aliases_for_room, + "local_aliases": list(aliases_for_room), "new_room_id": new_room_id, } diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index d236cc09b5..a965c7ec76 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -26,7 +26,13 @@ from synapse.api.constants import ( GuestAccess, Membership, ) -from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + PartialStateConflictError, + ShadowBanError, + SynapseError, +) from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event from synapse.events import EventBase @@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.logging import opentracing from synapse.module_api import NOT_SPAM -from synapse.storage.databases.main.events import PartialStateConflictError from synapse.types import ( JsonDict, Requester, @@ -56,6 +61,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class NoKnownServersError(SynapseError): + """No server already resident to the room was provided to the join/knock operation.""" + + def __init__(self, msg: str = "No known servers"): + super().__init__(404, msg) + + class RoomMemberHandler(metaclass=abc.ABCMeta): # TODO(paul): This handler currently contains a messy conflation of # low-level API that works on UserID objects and so on, and REST-level @@ -185,6 +197,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: Room that we are trying to join user: User who is trying to join content: A dict that should be used as the content of the join event. + + Raises: + NoKnownServersError: if remote_room_hosts does not contain a server joined to + the room. """ raise NotImplementedError() @@ -484,7 +500,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): user_id: The user's ID. """ # Retrieve user account data for predecessor room - user_account_data, _ = await self.store.get_account_data_for_user(user_id) + user_account_data = await self.store.get_global_account_data_for_user(user_id) # Copy direct message state if applicable direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {}) @@ -823,14 +839,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): latest_event_ids = await self.store.get_prev_events_for_room(room_id) - state_before_join = await self.state_handler.compute_state_after_events( - room_id, latest_event_ids + is_partial_state_room = await self.store.is_partial_state_room(room_id) + partial_state_before_join = await self.state_handler.compute_state_after_events( + room_id, latest_event_ids, await_full_state=False ) + # `is_partial_state_room` also indicates whether `partial_state_before_join` is + # partial. # TODO: Refactor into dictionary of explicitly allowed transitions # between old and new state, with specific error messages for some # transitions and generic otherwise - old_state_id = state_before_join.get((EventTypes.Member, target.to_string())) + old_state_id = partial_state_before_join.get( + (EventTypes.Member, target.to_string()) + ) if old_state_id: old_state = await self.store.get_event(old_state_id, allow_none=True) old_membership = old_state.content.get("membership") if old_state else None @@ -881,11 +902,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if action == "kick": raise AuthError(403, "The target user is not in the room") - is_host_in_room = await self._is_host_in_room(state_before_join) + is_host_in_room = await self._is_host_in_room(partial_state_before_join) if effective_membership_state == Membership.JOIN: if requester.is_guest: - guest_can_join = await self._can_guest_join(state_before_join) + guest_can_join = await self._can_guest_join(partial_state_before_join) if not guest_can_join: # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. @@ -927,8 +948,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id, remote_room_hosts, content, + is_partial_state_room, is_host_in_room, - state_before_join, + partial_state_before_join, ) if remote_join: if ratelimit: @@ -1073,8 +1095,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): room_id: str, remote_room_hosts: List[str], content: JsonDict, + is_partial_state_room: bool, is_host_in_room: bool, - state_before_join: StateMap[str], + partial_state_before_join: StateMap[str], ) -> Tuple[bool, List[str]]: """ Check whether the server should do a remote join (as opposed to a local @@ -1093,9 +1116,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): remote_room_hosts: A list of remote room hosts. content: The content to use as the event body of the join. This may be modified. - is_host_in_room: True if the host is in the room. - state_before_join: The state before the join event (i.e. the resolution of - the states after its parent events). + is_partial_state_room: `True` if the server currently doesn't hold the full + state of the room. + is_host_in_room: `True` if the host is in the room. + partial_state_before_join: The state before the join event (i.e. the + resolution of the states after its parent events). May be full or + partial state, depending on `is_partial_state_room`. Returns: A tuple of: @@ -1109,6 +1135,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if not is_host_in_room: return True, remote_room_hosts + prev_member_event_id = partial_state_before_join.get( + (EventTypes.Member, user_id), None + ) + previous_membership = None + if prev_member_event_id: + prev_member_event = await self.store.get_event(prev_member_event_id) + previous_membership = prev_member_event.membership + + # If we are not fully joined yet, and the target is not already in the room, + # let's do a remote join so another server with the full state can validate + # that the user has not been banned for example. + # We could just accept the join and wait for state res to resolve that later on + # but we would then leak room history to this person until then, which is pretty + # bad. + if is_partial_state_room and previous_membership != Membership.JOIN: + return True, remote_room_hosts + # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) @@ -1116,21 +1159,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If restricted join rules are not being used, a local join can always # be used. if not await self.event_auth_handler.has_restricted_join_rules( - state_before_join, room_version + partial_state_before_join, room_version ): return False, [] # If the user is invited to the room or already joined, the join # event can always be issued locally. - prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None) - prev_member_event = None - if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) - if prev_member_event.membership in ( - Membership.JOIN, - Membership.INVITE, - ): - return False, [] + if previous_membership in (Membership.JOIN, Membership.INVITE): + return False, [] + + # All the partial state cases are covered above. We have been given the full + # state of the room. + assert not is_partial_state_room + state_before_join = partial_state_before_join # If the local host has a user who can issue invites, then a local # join can be done. @@ -1154,7 +1195,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Ensure the member should be allowed access via membership in a room. await self.event_auth_handler.check_restricted_join_rules( - state_before_join, room_version, user_id, prev_member_event + state_before_join, room_version, user_id, previous_membership ) # If this is going to be a local join, additional information must @@ -1304,11 +1345,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool: + async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool: """ Returns whether a guest can join a room based on its current state. + + Args: + partial_current_state_ids: The current state of the room. May be full or + partial state. """ - guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + guest_access_id = partial_current_state_ids.get( + (EventTypes.GuestAccess, ""), None + ) if not guest_access_id: return False @@ -1634,19 +1681,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) return event, stream_id - async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool: + async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool: + """Returns whether the homeserver is in the room based on its current state. + + Args: + partial_current_state_ids: The current state of the room. May be full or + partial state. + """ # Have we just created the room, and is this about to be the very # first member event? - create_event_id = current_state_ids.get(("m.room.create", "")) - if len(current_state_ids) == 1 and create_event_id: + create_event_id = partial_current_state_ids.get(("m.room.create", "")) + if len(partial_current_state_ids) == 1 and create_event_id: # We can only get here if we're in the process of creating the room return True - for etype, state_key in current_state_ids: + for etype, state_key in partial_current_state_ids: if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): continue - event_id = current_state_ids[(etype, state_key)] + event_id = partial_current_state_ids[(etype, state_key)] event = await self.store.get_event(event_id, allow_none=True) if not event: continue @@ -1715,8 +1768,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ] if len(remote_room_hosts) == 0: - raise SynapseError( - 404, + raise NoKnownServersError( "Can't join remote room because no servers " "that are in the room have been provided.", ) @@ -1947,7 +1999,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ] if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") + raise NoKnownServersError() return await self.federation_handler.do_knock( remote_room_hosts, room_id, user.to_string(), content=content diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 221552a2a6..ba261702d4 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -15,8 +15,7 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple -from synapse.api.errors import SynapseError -from synapse.handlers.room_member import RoomMemberHandler +from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler from synapse.replication.http.membership import ( ReplicationRemoteJoinRestServlet as ReplRemoteJoin, ReplicationRemoteKnockRestServlet as ReplRemoteKnock, @@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) -> Tuple[str, int]: """Implements RoomMemberHandler._remote_join""" if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") + raise NoKnownServersError() ret = await self._remote_join_client( requester=requester, diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 4472019fbc..807245160d 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -521,8 +521,8 @@ class RoomSummaryHandler: It should return true if: - * The requester is joined or can join the room (per MSC3173). - * The origin server has any user that is joined or can join the room. + * The requesting user is joined or can join the room (per MSC3173); or + * The origin server has any user that is joined or can join the room; or * The history visibility is set to world readable. Args: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 3566537894..4e4595312c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -269,6 +269,8 @@ class SyncHandler: self._state_storage_controller = self._storage_controllers.state self._device_handler = hs.get_device_handler() + self.should_calculate_push_rules = hs.config.push.enable_push + # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token # that sets 'since' to 'next_batch'), we know that device won't need a @@ -1288,6 +1290,12 @@ class SyncHandler: async def unread_notifs_for_room_id( self, room_id: str, sync_config: SyncConfig ) -> RoomNotifCounts: + if not self.should_calculate_push_rules: + # If push rules have been universally disabled then we know we won't + # have any unread counts in the DB, so we may as well skip asking + # the DB. + return RoomNotifCounts.empty() + with Measure(self.clock, "unread_notifs_for_room_id"): return await self.store.get_unread_event_push_actions_by_room_for_user( @@ -1391,6 +1399,11 @@ class SyncHandler: for room_id, is_partial_state in results.items() if is_partial_state ) + membership_change_events = [ + event + for event in membership_change_events + if not results.get(event.room_id, False) + ] # Incremental eager syncs should additionally include rooms that # - we are joined to @@ -1444,9 +1457,9 @@ class SyncHandler: logger.debug("Fetching account data") - account_data_by_room = await self._generate_sync_entry_for_account_data( - sync_result_builder - ) + # Global account data is included if it is not filtered out. + if not sync_config.filter_collection.blocks_all_global_account_data(): + await self._generate_sync_entry_for_account_data(sync_result_builder) # Presence data is included if the server has it enabled and not filtered out. include_presence_data = bool( @@ -1472,9 +1485,7 @@ class SyncHandler: ( newly_joined_rooms, newly_left_rooms, - ) = await self._generate_sync_entry_for_rooms( - sync_result_builder, account_data_by_room - ) + ) = await self._generate_sync_entry_for_rooms(sync_result_builder) # Work out which users have joined or left rooms we're in. We use this # to build the presence and device_list parts of the sync response in @@ -1521,7 +1532,7 @@ class SyncHandler: one_time_keys_count = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = ( + unused_fallback_key_types = list( await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) @@ -1717,35 +1728,29 @@ class SyncHandler: async def _generate_sync_entry_for_account_data( self, sync_result_builder: "SyncResultBuilder" - ) -> Dict[str, Dict[str, JsonDict]]: - """Generates the account data portion of the sync response. + ) -> None: + """Generates the global account data portion of the sync response. Account data (called "Client Config" in the spec) can be set either globally or for a specific room. Account data consists of a list of events which accumulate state, much like a room. - This function retrieves global and per-room account data. The former is written - to the given `sync_result_builder`. The latter is returned directly, to be - later written to the `sync_result_builder` on a room-by-room basis. + This function retrieves global account data and writes it to the given + `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling + of per-room account data. Args: sync_result_builder - - Returns: - A dictionary whose keys (room ids) map to the per room account data for that - room. """ sync_config = sync_result_builder.sync_config user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token if since_token and not sync_result_builder.full_state: - # TODO Do not fetch room account data if it will be unused. - ( - global_account_data, - account_data_by_room, - ) = await self.store.get_updated_account_data_for_user( - user_id, since_token.account_data_key + global_account_data = ( + await self.store.get_updated_global_account_data_for_user( + user_id, since_token.account_data_key + ) ) push_rules_changed = await self.store.have_push_rules_changed_for_user( @@ -1753,31 +1758,31 @@ class SyncHandler: ) if push_rules_changed: + global_account_data = dict(global_account_data) global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) else: - # TODO Do not fetch room account data if it will be unused. - ( - global_account_data, - account_data_by_room, - ) = await self.store.get_account_data_for_user(sync_config.user.to_string()) + all_global_account_data = await self.store.get_global_account_data_for_user( + user_id + ) + global_account_data = dict(all_global_account_data) global_account_data["m.push_rules"] = await self.push_rules_for_user( sync_config.user ) - account_data_for_user = await sync_config.filter_collection.filter_account_data( - [ - {"type": account_data_type, "content": content} - for account_data_type, content in global_account_data.items() - ] + account_data_for_user = ( + await sync_config.filter_collection.filter_global_account_data( + [ + {"type": account_data_type, "content": content} + for account_data_type, content in global_account_data.items() + ] + ) ) sync_result_builder.account_data = account_data_for_user - return account_data_by_room - async def _generate_sync_entry_for_presence( self, sync_result_builder: "SyncResultBuilder", @@ -1837,9 +1842,7 @@ class SyncHandler: sync_result_builder.presence = presence async def _generate_sync_entry_for_rooms( - self, - sync_result_builder: "SyncResultBuilder", - account_data_by_room: Dict[str, Dict[str, JsonDict]], + self, sync_result_builder: "SyncResultBuilder" ) -> Tuple[AbstractSet[str], AbstractSet[str]]: """Generates the rooms portion of the sync response. Populates the `sync_result_builder` with the result. @@ -1850,7 +1853,6 @@ class SyncHandler: Args: sync_result_builder - account_data_by_room: Dictionary of per room account data Returns: Returns a 2-tuple describing rooms the user has joined or left. @@ -1863,9 +1865,30 @@ class SyncHandler: since_token = sync_result_builder.since_token user_id = sync_result_builder.sync_config.user.to_string() + blocks_all_rooms = ( + sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + ) + + # 0. Start by fetching room account data (if required). + if ( + blocks_all_rooms + or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data() + ): + account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {} + elif since_token and not sync_result_builder.full_state: + account_data_by_room = ( + await self.store.get_updated_room_account_data_for_user( + user_id, since_token.account_data_key + ) + ) + else: + account_data_by_room = await self.store.get_room_account_data_for_user( + user_id + ) + # 1. Start by fetching all ephemeral events in rooms we've joined (if required). block_all_room_ephemeral = ( - sync_result_builder.sync_config.filter_collection.blocks_all_rooms() + blocks_all_rooms or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() ) if block_all_room_ephemeral: @@ -2291,8 +2314,8 @@ class SyncHandler: sync_result_builder: "SyncResultBuilder", room_builder: "RoomSyncResultBuilder", ephemeral: List[JsonDict], - tags: Optional[Dict[str, Dict[str, Any]]], - account_data: Dict[str, JsonDict], + tags: Optional[Mapping[str, Mapping[str, Any]]], + account_data: Mapping[str, JsonDict], always_include: bool = False, ) -> None: """Populates the `joined` and `archived` section of `sync_result_builder` diff --git a/synapse/http/server.py b/synapse/http/server.py index 2563858f3c..9314454af1 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -30,7 +30,6 @@ from typing import ( Iterable, Iterator, List, - NoReturn, Optional, Pattern, Tuple, @@ -340,7 +339,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - return _unrecognised_request_handler(request) + # A request with an unknown method (for a known endpoint) was received. + raise UnrecognizedRequestError(code=405) @abc.abstractmethod def _send_response( @@ -396,7 +396,6 @@ class DirectServeJsonResource(_AsyncResource): @attr.s(slots=True, frozen=True, auto_attribs=True) class _PathEntry: - pattern: Pattern callback: ServletCallback servlet_classname: str @@ -425,13 +424,14 @@ class JsonResource(DirectServeJsonResource): ): super().__init__(canonical_json, extract_context) self.clock = hs.get_clock() - self.path_regexs: Dict[bytes, List[_PathEntry]] = {} + # Map of path regex -> method -> callback. + self._routes: Dict[Pattern[str], Dict[bytes, _PathEntry]] = {} self.hs = hs def register_paths( self, method: str, - path_patterns: Iterable[Pattern], + path_patterns: Iterable[Pattern[str]], callback: ServletCallback, servlet_classname: str, ) -> None: @@ -455,8 +455,8 @@ class JsonResource(DirectServeJsonResource): for path_pattern in path_patterns: logger.debug("Registering for %s %s", method, path_pattern.pattern) - self.path_regexs.setdefault(method_bytes, []).append( - _PathEntry(path_pattern, callback, servlet_classname) + self._routes.setdefault(path_pattern, {})[method_bytes] = _PathEntry( + callback, servlet_classname ) def _get_handler_for_request( @@ -478,14 +478,17 @@ class JsonResource(DirectServeJsonResource): # Loop through all the registered callbacks to check if the method # and path regex match - for path_entry in self.path_regexs.get(request_method, []): - m = path_entry.pattern.match(request_path) + for path_pattern, methods in self._routes.items(): + m = path_pattern.match(request_path) if m: - # We found a match! + # We found a matching path! + path_entry = methods.get(request_method) + if not path_entry: + raise UnrecognizedRequestError(code=405) return path_entry.callback, path_entry.servlet_classname, m.groupdict() - # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - return _unrecognised_request_handler, "unrecognised_request_handler", {} + # Huh. No one wanted to handle that? Fiiiiiine. + raise UnrecognizedRequestError(code=404) async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) @@ -567,19 +570,6 @@ class StaticResource(File): return super().render_GET(request) -def _unrecognised_request_handler(request: Request) -> NoReturn: - """Request handler for unrecognised requests - - This is a request handler suitable for return from - _get_handler_for_request. It actually just raises an - UnrecognizedRequestError. - - Args: - request: Unused, but passed in to match the signature of ServletCallback. - """ - raise UnrecognizedRequestError(code=404) - - class UnrecognizedRequestResource(resource.Resource): """ Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 8ef9a0dda8..6c7cf1b294 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -466,8 +466,16 @@ def init_tracer(hs: "HomeServer") -> None: STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name() ) + jaeger_config = hs.config.tracing.jaeger_config + tags = jaeger_config.setdefault("tags", {}) + + # tag the Synapse instance name so that it's an easy jumping + # off point into the logs. Can also be used to filter for an + # instance that is under load. + tags[SynapseTags.INSTANCE_NAME] = hs.get_instance_name() + config = JaegerConfig( - config=hs.config.tracing.jaeger_config, + config=jaeger_config, service_name=f"{hs.config.server.server_name} {instance_name_by_type}", scope_manager=LogContextScopeManager(), metrics_factory=PrometheusMetricsFactory(), diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index d9c0a98f44..f6a5bffb0f 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -22,6 +22,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -43,6 +44,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator +from synapse.types import SimpleJsonValue from synapse.types.state import StateFilter from synapse.util.caches import register_cache from synapse.util.metrics import measure_func @@ -148,7 +150,7 @@ class BulkPushRuleEvaluator: # little, we can skip fetching a huge number of push rules in large rooms. # This helps make joins and leaves faster. if event.type == EventTypes.Member: - local_users = [] + local_users: Sequence[str] = [] # We never notify a user about their own actions. This is enforced in # `_action_for_event_by_user` in the loop over `rules_by_user`, but we # do the same check here to avoid unnecessary DB queries. @@ -183,7 +185,6 @@ class BulkPushRuleEvaluator: if event.type == EventTypes.Member and event.membership == Membership.INVITE: invited = event.state_key if invited and self.hs.is_mine_id(invited) and invited not in local_users: - local_users = list(local_users) local_users.append(invited) if not local_users: @@ -256,13 +257,15 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]: + async def _related_events( + self, event: EventBase + ) -> Dict[str, Dict[str, SimpleJsonValue]]: """Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation Returns: Mapping of relation type to flattened events. """ - related_events: Dict[str, Dict[str, str]] = {} + related_events: Dict[str, Dict[str, SimpleJsonValue]] = {} if self._related_event_match_enabled: related_event_id = event.content.get("m.relates_to", {}).get("event_id") relation_type = event.content.get("m.relates_to", {}).get("rel_type") @@ -271,7 +274,10 @@ class BulkPushRuleEvaluator: related_event_id, allow_none=True ) if related_event is not None: - related_events[relation_type] = _flatten_dict(related_event) + related_events[relation_type] = _flatten_dict( + related_event, + msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + ) reply_event_id = ( event.content.get("m.relates_to", {}) @@ -286,7 +292,10 @@ class BulkPushRuleEvaluator: ) if related_event is not None: - related_events["m.in_reply_to"] = _flatten_dict(related_event) + related_events["m.in_reply_to"] = _flatten_dict( + related_event, + msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + ) # indicate that this is from a fallback relation. if relation_type == "m.thread" and event.content.get( @@ -405,7 +414,10 @@ class BulkPushRuleEvaluator: room_mention = mentions.get("room") is True evaluator = PushRuleEvaluator( - _flatten_dict(event), + _flatten_dict( + event, + msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key, + ), has_mentions, user_mentions, room_mention, @@ -416,6 +428,7 @@ class BulkPushRuleEvaluator: self._related_event_match_enabled, event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag + self.hs.config.experimental.msc3758_exact_event_match, ) users = rules_by_user.keys() @@ -492,13 +505,15 @@ StateGroup = Union[object, int] def _flatten_dict( d: Union[EventBase, Mapping[str, Any]], prefix: Optional[List[str]] = None, - result: Optional[Dict[str, str]] = None, -) -> Dict[str, str]: + result: Optional[Dict[str, SimpleJsonValue]] = None, + *, + msc3783_escape_event_match_key: bool = False, +) -> Dict[str, SimpleJsonValue]: """ Given a JSON dictionary (or event) which might contain sub dictionaries, flatten it into a single layer dictionary by combining the keys & sub-keys. - Any (non-dictionary), non-string value is dropped. + String, integer, boolean, and null values are kept. All others are dropped. Transforms: @@ -521,11 +536,22 @@ def _flatten_dict( if result is None: result = {} for key, value in d.items(): - if isinstance(value, str): - result[".".join(prefix + [key])] = value.lower() + if msc3783_escape_event_match_key: + # Escape periods in the key with a backslash (and backslashes with an + # extra backslash). This is since a period is used as a separator between + # nested fields. + key = key.replace("\\", "\\\\").replace(".", "\\.") + + if isinstance(value, (bool, str)) or type(value) is int or value is None: + result[".".join(prefix + [key])] = value elif isinstance(value, Mapping): # do not set `room_version` due to recursion considerations below - _flatten_dict(value, prefix=(prefix + [key]), result=result) + _flatten_dict( + value, + prefix=(prefix + [key]), + result=result, + msc3783_escape_event_match_key=msc3783_escape_event_match_key, + ) # `room_version` should only ever be set when looking at the top level of an event if ( diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 0d072c42a7..c134ccfb3d 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -15,7 +15,7 @@ import logging from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.constants import Direction from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -285,7 +285,12 @@ class DeleteMediaByDateSize(RestServlet): timestamp and size. """ - PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$") + PATTERNS = [ + *admin_patterns("/media/delete$"), + # This URL kept around for legacy reasons, it is undesirable since it + # overlaps with the DeleteMediaByID servlet. + *admin_patterns("/media/(?P<server_name>[^/]*)/delete$"), + ] def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main @@ -294,7 +299,7 @@ class DeleteMediaByDateSize(RestServlet): self.media_repository = hs.get_media_repository() async def on_POST( - self, request: SynapseRequest, server_name: str + self, request: SynapseRequest, server_name: Optional[str] = None ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -322,7 +327,8 @@ class DeleteMediaByDateSize(RestServlet): errcode=Codes.INVALID_PARAM, ) - if self.server_name != server_name: + # This check is useless, we keep it for the legacy endpoint only. + if server_name is not None and self.server_name != server_name: raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media") logging.info( @@ -489,6 +495,8 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer) ProtectMediaByID(hs).register(http_server) UnprotectMediaByID(hs).register(http_server) ListMediaInRoom(hs).register(http_server) - DeleteMediaByID(hs).register(http_server) + # XXX DeleteMediaByDateSize must be registered before DeleteMediaByID as + # their URL routes overlap. DeleteMediaByDateSize(hs).register(http_server) + DeleteMediaByID(hs).register(http_server) UserMediaRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index b9dca8ef3a..0c0bf540b9 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -1192,7 +1192,8 @@ class AccountDataRestServlet(RestServlet): if not await self._store.get_user_by_id(user_id): raise NotFoundError("User not found") - global_data, by_room_data = await self._store.get_account_data_for_user(user_id) + global_data = await self._store.get_global_account_data_for_user(user_id) + by_room_data = await self._store.get_room_account_data_for_user(user_id) return HTTPStatus.OK, { "account_data": { "global": global_data, diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 4373c73662..662f5bf762 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -415,6 +415,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): request, MsisdnRequestTokenBody ) msisdn = phone_number_to_msisdn(body.country, body.phone_number) + logger.info("Request #%s to verify ownership of %s", body.send_attempt, msisdn) if not await check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -444,6 +445,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): await self.hs.get_clock().sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} + logger.info("MSISDN %s is already in use by %s", msisdn, existing_user_id) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) if not self.hs.config.registration.account_threepid_delegate_msisdn: @@ -468,6 +470,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( body.send_attempt ) + logger.info("MSISDN %s: got response from identity server: %s", msisdn, ret) return 200, ret @@ -734,12 +737,7 @@ class ThreepidUnbindRestServlet(RestServlet): # Attempt to unbind the threepid from an identity server. If id_server is None, try to # unbind from all identity servers this threepid has been added to in the past result = await self.identity_handler.try_unbind_threepid( - requester.user.to_string(), - { - "address": body.address, - "medium": body.medium, - "id_server": body.id_server, - }, + requester.user.to_string(), body.medium, body.address, body.id_server ) return 200, {"id_server_unbind_result": "success" if result else "no-support"} diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index f7081f638e..4e7ffdb555 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -259,6 +259,32 @@ class RoomKeysNewVersionServlet(RestServlet): self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + """ + Retrieve the version information about the most current backup version (if any) + + It takes out an exclusive lock on this user's room_key backups, to ensure + clients only upload to the current backup. + + Returns 404 if the given version does not exist. + + GET /room_keys/version HTTP/1.1 + { + "version": "12345", + "algorithm": "m.megolm_backup.v1", + "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" + } + """ + requester = await self.auth.get_user_by_req(request, allow_guest=False) + user_id = requester.user.to_string() + + try: + info = await self.e2e_room_keys_handler.get_version_info(user_id) + except SynapseError as e: + if e.code == 404: + raise SynapseError(404, "No backup found", Codes.NOT_FOUND) + return 200, info + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """ Create a new backup version for this user's room_keys with the given @@ -301,7 +327,7 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$") + PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$") def __init__(self, hs: "HomeServer"): super().__init__() @@ -309,12 +335,11 @@ class RoomKeysVersionServlet(RestServlet): self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() async def on_GET( - self, request: SynapseRequest, version: Optional[str] + self, request: SynapseRequest, version: str ) -> Tuple[int, JsonDict]: """ Retrieve the version information about a given version of the user's - room_keys backup. If the version part is missing, returns info about the - most current backup version (if any) + room_keys backup. It takes out an exclusive lock on this user's room_key backups, to ensure clients only upload to the current backup. @@ -339,20 +364,16 @@ class RoomKeysVersionServlet(RestServlet): return 200, info async def on_DELETE( - self, request: SynapseRequest, version: Optional[str] + self, request: SynapseRequest, version: str ) -> Tuple[int, JsonDict]: """ Delete the information about a given version of the user's - room_keys backup. If the version part is missing, deletes the most - current backup version (if any). Doesn't delete the actual room data. + room_keys backup. Doesn't delete the actual room data. DELETE /room_keys/version/12345 HTTP/1.1 HTTP/1.1 200 OK {} """ - if version is None: - raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) - requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() @@ -360,7 +381,7 @@ class RoomKeysVersionServlet(RestServlet): return 200, {} async def on_PUT( - self, request: SynapseRequest, version: Optional[str] + self, request: SynapseRequest, version: str ) -> Tuple[int, JsonDict]: """ Update the information about a given version of the user's room_keys backup. @@ -386,11 +407,6 @@ class RoomKeysVersionServlet(RestServlet): user_id = requester.user.to_string() info = parse_json_object_from_request(request) - if version is None: - raise SynapseError( - 400, "No version specified to update", Codes.MISSING_PARAM - ) - await self.e2e_room_keys_handler.update_version(user_id, version, info) return 200, {} diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py index ca638755c7..dde08417a4 100644 --- a/synapse/rest/client/tags.py +++ b/synapse/rest/client/tags.py @@ -34,7 +34,9 @@ class TagListServlet(RestServlet): GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags") + PATTERNS = client_patterns( + "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$" + ) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index a5c3de192f..db25848744 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -46,10 +46,9 @@ from ._base import FileInfo, Responder from .filepath import MediaFilePaths if TYPE_CHECKING: + from synapse.rest.media.v1.storage_provider import StorageProvider from synapse.server import HomeServer - from .storage_provider import StorageProviderWrapper - logger = logging.getLogger(__name__) @@ -68,7 +67,7 @@ class MediaStorage: hs: "HomeServer", local_media_directory: str, filepaths: MediaFilePaths, - storage_providers: Sequence["StorageProviderWrapper"], + storage_providers: Sequence["StorageProvider"], ): self.hs = hs self.reactor = hs.get_reactor() @@ -360,7 +359,7 @@ class ReadableFileWrapper: clock: Clock path: str - async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None: + async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None: """Reads the file in chunks and calls the callback with each chunk.""" with open(self.path, "rb") as file: diff --git a/synapse/server.py b/synapse/server.py index 9d6d268f49..efc6b5f895 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -21,7 +21,7 @@ import abc import functools import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port @@ -144,10 +144,10 @@ if TYPE_CHECKING: from synapse.handlers.saml import SamlHandler -T = TypeVar("T", bound=Callable[..., Any]) +T = TypeVar("T") -def cache_in_self(builder: T) -> T: +def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]: """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and returning if so. If not, calls the given function and sets `self.foo` to it. @@ -166,7 +166,7 @@ def cache_in_self(builder: T) -> T: building = [False] @functools.wraps(builder) - def _get(self): + def _get(self: "HomeServer") -> T: try: return getattr(self, depname) except AttributeError: @@ -185,9 +185,7 @@ def cache_in_self(builder: T) -> T: return dep - # We cast here as we need to tell mypy that `_get` has the same signature as - # `builder`. - return cast(T, _get) + return _get class HomeServer(metaclass=abc.ABCMeta): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fdfb46ab82..4dc25df67e 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.events import EventBase -from synapse.events.snapshot import EventContext +from synapse.events.snapshot import ( + EventContext, + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.logging.context import ContextResourceUsage from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 @@ -222,7 +226,7 @@ class StateHandler: return await ret.get_state(self._state_storage_controller, state_filter) async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: List[str] + self, room_id: str, latest_event_ids: Collection[str] ) -> Set[str]: """ Get the users IDs who are currently in a room. @@ -262,31 +266,31 @@ class StateHandler: state = await entry.get_state(self._state_storage_controller, StateFilter.all()) return await self.store.get_joined_hosts(room_id, state, entry) - async def compute_event_context( + async def calculate_context_info( self, event: EventBase, state_ids_before_event: Optional[StateMap[str]] = None, partial_state: Optional[bool] = None, - ) -> EventContext: - """Build an EventContext structure for a non-outlier event. - - (for an outlier, call EventContext.for_outlier directly) - - This works out what the current state should be for the event, and - generates a new state group if necessary. - - Args: - event: - state_ids_before_event: The event ids of the state before the event if - it can't be calculated from existing events. This is normally - only specified when receiving an event from federation where we - don't have the prev events, e.g. when backfilling. - partial_state: - `True` if `state_ids_before_event` is partial and omits non-critical - membership events. - `False` if `state_ids_before_event` is the full state. - `None` when `state_ids_before_event` is not provided. In this case, the - flag will be calculated based on `event`'s prev events. + state_group_before_event: Optional[int] = None, + ) -> UnpersistedEventContextBase: + """ + Calulates the contents of an unpersisted event context, other than the current + state group (which is either provided or calculated when the event context is persisted) + + state_ids_before_event: + The event ids of the full state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling or when the event + is being created for batch persisting. + partial_state: + `True` if `state_ids_before_event` is partial and omits non-critical + membership events. + `False` if `state_ids_before_event` is the full state. + `None` when `state_ids_before_event` is not provided. In this case, the + flag will be calculated based on `event`'s prev events. + state_group_before_event: + the current state group at the time of event, if known Returns: The event context. @@ -294,7 +298,6 @@ class StateHandler: RuntimeError if `state_ids_before_event` is not provided and one or more prev events are missing or outliers. """ - assert not event.internal_metadata.is_outlier() # @@ -306,17 +309,6 @@ class StateHandler: state_group_before_event_prev_group = None deltas_to_state_group_before_event = None - # .. though we need to get a state group for it. - state_group_before_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=None, - delta_ids=None, - current_state_ids=state_ids_before_event, - ) - ) - # the partial_state flag must be provided assert partial_state is not None else: @@ -345,6 +337,7 @@ class StateHandler: logger.debug("calling resolve_state_groups from compute_event_context") # we've already taken into account partial state, so no need to wait for # complete state here. + entry = await self.resolve_state_groups_for_events( event.room_id, event.prev_event_ids(), @@ -383,18 +376,19 @@ class StateHandler: # if not event.is_state(): - return EventContext.with_state( + return UnpersistedEventContext( storage=self._storage_controllers, state_group_before_event=state_group_before_event, - state_group=state_group_before_event, + state_group_after_event=state_group_before_event, state_delta_due_to_event={}, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, + prev_group_for_state_group_before_event=state_group_before_event_prev_group, + delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, partial_state=partial_state, + state_map_before_event=state_ids_before_event, ) # - # otherwise, we'll need to create a new state group for after the event + # otherwise, we'll need to set up creating a new state group for after the event # key = (event.type, event.state_key) @@ -412,88 +406,60 @@ class StateHandler: delta_ids = {key: event.event_id} - state_group_after_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=None, - ) - ) - - return EventContext.with_state( + return UnpersistedEventContext( storage=self._storage_controllers, - state_group=state_group_after_event, state_group_before_event=state_group_before_event, + state_group_after_event=None, state_delta_due_to_event=delta_ids, - prev_group=state_group_before_event, - delta_ids=delta_ids, + prev_group_for_state_group_before_event=state_group_before_event_prev_group, + delta_ids_to_state_group_before_event=deltas_to_state_group_before_event, partial_state=partial_state, + state_map_before_event=state_ids_before_event, ) - async def compute_event_context_for_batched( + async def compute_event_context( self, event: EventBase, - state_ids_before_event: StateMap[str], - current_state_group: int, + state_ids_before_event: Optional[StateMap[str]] = None, + partial_state: Optional[bool] = None, ) -> EventContext: - """ - Generate an event context for an event that has not yet been persisted to the - database. Intended for use with events that are created to be persisted in a batch. - Args: - event: the event the context is being computed for - state_ids_before_event: a state map consisting of the state ids of the events - created prior to this event. - current_state_group: the current state group before the event. - """ - state_group_before_event_prev_group = None - deltas_to_state_group_before_event = None - - state_group_before_event = current_state_group - - # if the event is not state, we are set - if not event.is_state(): - return EventContext.with_state( - storage=self._storage_controllers, - state_group_before_event=state_group_before_event, - state_group=state_group_before_event, - state_delta_due_to_event={}, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - partial_state=False, - ) + """Build an EventContext structure for a non-outlier event. - # otherwise, we'll need to create a new state group for after the event - key = (event.type, event.state_key) + (for an outlier, call EventContext.for_outlier directly) - if state_ids_before_event is not None: - replaces = state_ids_before_event.get(key) + This works out what the current state should be for the event, and + generates a new state group if necessary. - if replaces and replaces != event.event_id: - event.unsigned["replaces_state"] = replaces + Args: + event: + state_ids_before_event: The event ids of the state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling. + partial_state: + `True` if `state_ids_before_event` is partial and omits non-critical + membership events. + `False` if `state_ids_before_event` is the full state. + `None` when `state_ids_before_event` is not provided. In this case, the + flag will be calculated based on `event`'s prev events. + entry: + A state cache entry for the resolved state across the prev events. We may + have already calculated this, so if it's available pass it in + Returns: + The event context. - delta_ids = {key: event.event_id} + Raises: + RuntimeError if `state_ids_before_event` is not provided and one or more + prev events are missing or outliers. + """ - state_group_after_event = ( - await self._state_storage_controller.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=None, - ) + unpersisted_context = await self.calculate_context_info( + event=event, + state_ids_before_event=state_ids_before_event, + partial_state=partial_state, ) - return EventContext.with_state( - storage=self._storage_controllers, - state_group=state_group_after_event, - state_group_before_event=state_group_before_event, - state_delta_due_to_event=delta_ids, - prev_group=state_group_before_event, - delta_ids=delta_ids, - partial_state=False, - ) + return await unpersisted_context.persist(event) @measure_func() async def resolve_state_groups_for_events( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 41d9111019..481fec72fe 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ + db_pool: DatabasePool + def __init__( self, database: DatabasePool, diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 52efd4a171..9d7a8a792f 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,6 +14,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Callable, @@ -23,7 +24,6 @@ from typing import ( List, Mapping, Optional, - Set, Tuple, ) @@ -527,7 +527,7 @@ class StateStorageController: ) return state_map.get(key) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms @@ -584,7 +584,7 @@ class StateStorageController: async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """ Get the current users in the room with their profiles. If the room is currently partial-stated, this will block until the room has diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e20c5c5302..feaa6cdd07 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -499,6 +499,7 @@ class DatabasePool: """ _TXN_ID = 0 + engine: BaseDatabaseEngine def __init__( self, diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 8a359d7eb8..95567826f2 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -21,6 +21,7 @@ from typing import ( FrozenSet, Iterable, List, + Mapping, Optional, Tuple, cast, @@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return self._account_data_id_gen.get_current_token() @cached() - async def get_account_data_for_user( + async def get_global_account_data_for_user( self, user_id: str - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: + ) -> Mapping[str, JsonDict]: """ - Get all the client account_data for a user. + Get all the global client account_data for a user. If experimental MSC3391 support is enabled, any entries with an empty content body are excluded; as this means they have been deleted. Args: user_id: The user to get the account_data for. + Returns: - A 2-tuple of a dict of global account_data and a dict mapping from - room_id string to per room account_data dicts. + The global account_data. """ - def get_account_data_for_user_txn( + def get_global_account_data_for_user( txn: LoggingTransaction, - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: + ) -> Dict[str, JsonDict]: # The 'content != '{}' condition below prevents us from using # `simple_select_list_txn` here, as it doesn't support conditions # other than 'equals'. @@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn.execute(sql, (user_id,)) rows = self.db_pool.cursor_to_dict(txn) - global_account_data = { + return { row["account_data_type"]: db_to_json(row["content"]) for row in rows } + return await self.db_pool.runInteraction( + "get_global_account_data_for_user", get_global_account_data_for_user + ) + + @cached() + async def get_room_account_data_for_user( + self, user_id: str + ) -> Mapping[str, Mapping[str, JsonDict]]: + """ + Get all of the per-room client account_data for a user. + + If experimental MSC3391 support is enabled, any entries with an empty + content body are excluded; as this means they have been deleted. + + Args: + user_id: The user to get the account_data for. + + Returns: + A dict mapping from room_id string to per-room account_data dicts. + """ + + def get_room_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, JsonDict]]: # The 'content != '{}' condition below prevents us from using # `simple_select_list_txn` here, as it doesn't support conditions # other than 'equals'. @@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) room_data[row["account_data_type"]] = db_to_json(row["content"]) - return global_account_data, by_room + return by_room return await self.db_pool.runInteraction( - "get_account_data_for_user", get_account_data_for_user_txn + "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn ) @cached(num_args=2, max_entries=5000, tree=True) @@ -215,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) @cached(num_args=2, tree=True) async def get_account_data_for_room( self, user_id: str, room_id: str - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get all the client account_data for a user for a room. Args: @@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) "get_updated_room_account_data", get_updated_room_account_data_txn ) - async def get_updated_account_data_for_user( + async def get_updated_global_account_data_for_user( self, user_id: str, stream_id: int - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - """Get all the client account_data for a that's changed for a user + ) -> Dict[str, JsonDict]: + """Get all the global account_data that's changed for a user. Args: user_id: The user to get the account_data for. stream_id: The point in the stream since which to get updates + Returns: - A deferred pair of a dict of global account_data and a dict - mapping from room_id string to per room account_data dicts. + A dict of global account_data. """ - def get_updated_account_data_for_user_txn( + def get_updated_global_account_data_for_user( txn: LoggingTransaction, - ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: - sql = ( - "SELECT account_data_type, content FROM account_data" - " WHERE user_id = ? AND stream_id > ?" - ) - + ) -> Dict[str, JsonDict]: + sql = """ + SELECT account_data_type, content FROM account_data + WHERE user_id = ? AND stream_id > ? + """ txn.execute(sql, (user_id, stream_id)) - global_account_data = {row[0]: db_to_json(row[1]) for row in txn} + return {row[0]: db_to_json(row[1]) for row in txn} - sql = ( - "SELECT room_id, account_data_type, content FROM room_account_data" - " WHERE user_id = ? AND stream_id > ?" - ) + changed = self._account_data_stream_cache.has_entity_changed( + user_id, int(stream_id) + ) + if not changed: + return {} + + return await self.db_pool.runInteraction( + "get_updated_global_account_data_for_user", + get_updated_global_account_data_for_user, + ) + + async def get_updated_room_account_data_for_user( + self, user_id: str, stream_id: int + ) -> Dict[str, Dict[str, JsonDict]]: + """Get all the room account_data that's changed for a user. + Args: + user_id: The user to get the account_data for. + stream_id: The point in the stream since which to get updates + + Returns: + A dict mapping from room_id string to per room account_data dicts. + """ + + def get_updated_room_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Dict[str, Dict[str, JsonDict]]: + sql = """ + SELECT room_id, account_data_type, content FROM room_account_data + WHERE user_id = ? AND stream_id > ? + """ txn.execute(sql, (user_id, stream_id)) account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} @@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) - return global_account_data, account_data_by_room + return account_data_by_room changed = self._account_data_stream_cache.has_entity_changed( user_id, int(stream_id) ) if not changed: - return {}, {} + return {} return await self.db_pool.runInteraction( - "get_updated_account_data_for_user", get_updated_account_data_for_user_txn + "get_updated_room_account_data_for_user", + get_updated_room_account_data_for_user_txn, ) @cached(max_entries=5000, iterable=True) @@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self.get_global_account_data_by_type_for_user.invalidate( (row.user_id, row.data_type) ) - self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_global_account_data_for_user.invalidate((row.user_id,)) + self.get_room_account_data_for_user.invalidate((row.user_id,)) self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) self.get_account_data_for_room_and_type.invalidate( (row.user_id, row.room_id, row.data_type) @@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_room_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), content @@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return None self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_room_account_data_for_user.invalidate((user_id,)) self.get_account_data_for_room.invalidate((user_id, room_id)) self.get_account_data_for_room_and_type.prefill( (user_id, room_id, account_data_type), {} @@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) ) self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.invalidate( (user_id, account_data_type) ) @@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) return None self._account_data_stream_cache.entity_has_changed(user_id, next_id) - self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_for_user.invalidate((user_id,)) self.get_global_account_data_by_type_for_user.prefill( (user_id, account_data_type), {} ) @@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) txn, self.get_account_data_for_room_and_type, (user_id,) ) self._invalidate_cache_and_stream( - txn, self.get_account_data_for_user, (user_id,) + txn, self.get_global_account_data_for_user, (user_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_room_account_data_for_user, (user_id,) ) self._invalidate_cache_and_stream( txn, self.get_global_account_data_by_type_for_user, (user_id,) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index 5fb152c4ff..484db175d0 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore): room_id: str, app_service: "ApplicationService", cache_context: _CacheContext, - ) -> List[str]: + ) -> Sequence[str]: """ Get all users in a room that the appservice controls. diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e8b6cc6b80..1ca66d57d4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -21,6 +21,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -100,6 +101,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ("device_lists_outbound_pokes", "stream_id"), ("device_lists_changes_in_room", "stream_id"), ("device_lists_remote_pending", "stream_id"), + ("device_lists_changes_converted_stream_position", "stream_id"), ], is_writer=hs.config.worker.worker_app is None, ) @@ -201,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() - async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int: + async def count_devices_by_users( + self, user_ids: Optional[Collection[str]] = None + ) -> int: """Retrieve number of all devices of given users. Only returns number of devices that are not marked as hidden. @@ -212,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ def count_devices_by_users_txn( - txn: LoggingTransaction, user_ids: List[str] + txn: LoggingTransaction, user_ids: Collection[str] ) -> int: sql = """ SELECT count(*) @@ -745,42 +749,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @trace @cancellable async def get_user_devices_from_cache( - self, query_list: List[Tuple[str, Optional[str]]] - ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: + self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] + ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. Args: - query_list: List of (user_id, device_ids), if device_ids is - falsey then return all device ids for that user. + user_ids: users which should have all device IDs returned + user_and_device_ids: List of (user_id, device_ids) Returns: A tuple of (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info. """ - user_ids = {user_id for user_id, _ in query_list} - user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids} + user_map = await self.get_device_list_last_stream_id_for_remotes( + list(unique_user_ids) + ) # We go and check if any of the users need to have their device lists # resynced. If they do then we remove them from the cached list. users_needing_resync = await self.get_user_ids_requiring_device_list_resync( - user_ids + unique_user_ids ) user_ids_in_cache = { user_id for user_id, stream_id in user_map.items() if stream_id } - users_needing_resync - user_ids_not_in_cache = user_ids - user_ids_in_cache - - results: Dict[str, Dict[str, JsonDict]] = {} - for user_id, device_id in query_list: - if user_id not in user_ids_in_cache: - continue + user_ids_not_in_cache = unique_user_ids - user_ids_in_cache - if device_id: - device = await self._get_cached_user_device(user_id, device_id) - results.setdefault(user_id, {})[device_id] = device - else: + # First fetch all the users which all devices are to be returned. + results: Dict[str, Mapping[str, JsonDict]] = {} + for user_id in user_ids: + if user_id in user_ids_in_cache: results[user_id] = await self.get_cached_devices_for_user(user_id) + # Then fetch all device-specific requests, but skip users we've already + # fetched all devices for. + device_specific_results: Dict[str, Dict[str, JsonDict]] = {} + for user_id, device_id in user_and_device_ids: + if user_id in user_ids_in_cache and user_id not in user_ids: + device = await self._get_cached_user_device(user_id, device_id) + device_specific_results.setdefault(user_id, {})[device_id] = device + results.update(device_specific_results) set_tag("in_cache", str(results)) set_tag("not_in_cache", str(user_ids_not_in_cache)) @@ -798,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return db_to_json(content) @cached() - async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]: + async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 5903fdaf00..44aa181174 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Sequence, Tuple import attr @@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore): ) @cached(max_entries=5000) - async def get_aliases_for_room(self, room_id: str) -> List[str]: + async def get_aliases_for_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( "room_aliases", {"room_id": room_id}, diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c4ac6c33ba..2c2d145666 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -20,7 +20,9 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, Union, cast, @@ -260,7 +262,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker for batch in batch_iter(signature_query, 50): cross_sigs_result = await self.db_pool.runInteraction( - "get_e2e_cross_signing_signatures", + "get_e2e_cross_signing_signatures_for_devices", self._get_e2e_cross_signing_signatures_for_devices_txn, batch, ) @@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cached(max_entries=10000) async def get_e2e_unused_fallback_key_types( self, user_id: str, device_id: str - ) -> List[str]: + ) -> Sequence[str]: """Returns the fallback key types that have an unused key. Args: @@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]: + def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) # The `Optional` comes from the `@cachedList` decorator. - return cast(Dict[str, Optional[Dict[str, JsonDict]]], result) + return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) def _get_bare_e2e_cross_signing_keys_bulk_txn( self, @@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Dict[str, JsonDict]]]: + ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids) if from_user_id: - result = await self.db_pool.runInteraction( - "get_e2e_cross_signing_signatures", - self._get_e2e_cross_signing_signatures_txn, - result, - from_user_id, + result = cast( + Dict[str, Optional[Mapping[str, JsonDict]]], + await self.db_pool.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ), ) return result diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index bbee02ab18..ca780cca36 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -22,6 +22,7 @@ from typing import ( Iterable, List, Optional, + Sequence, Set, Tuple, cast, @@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas room_id, ) - async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: + async def get_max_depth_of( + self, event_ids: Collection[str] + ) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: @@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) @cached(max_entries=5000, iterable=True) - async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]: + async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]: return await self.db_pool.simple_select_onecol( table="event_forward_extremities", keyvalues={"room_id": room_id}, @@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @cancellable async def get_forward_extremities_for_room_at_stream_ordering( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas @cached(max_entries=5000, num_args=2) async def _get_forward_extremeties_for_room( self, room_id: str, stream_ordering: int - ) -> List[str]: + ) -> Sequence[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 3a0c370fde..eeccf5db24 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -203,11 +203,18 @@ class RoomNotifCounts: # Map of thread ID to the notification counts. threads: Dict[str, NotifCounts] + @staticmethod + def empty() -> "RoomNotifCounts": + return _EMPTY_ROOM_NOTIF_COUNTS + def __len__(self) -> int: # To properly account for the amount of space in any caches. return len(self.threads) + 1 +_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {}) + + def _serialize_action( actions: Collection[Union[Mapping, str]], is_highlight: bool ) -> str: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1536937b67..7996cbb557 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -16,7 +16,6 @@ import itertools import logging from collections import OrderedDict -from http import HTTPStatus from typing import ( TYPE_CHECKING, Any, @@ -26,7 +25,6 @@ from typing import ( Iterable, List, Optional, - Sequence, Set, Tuple, ) @@ -36,7 +34,7 @@ from prometheus_client import Counter import synapse.metrics from synapse.api.constants import EventContentFields, EventTypes, RelationTypes -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext @@ -52,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import JsonDict, StateMap, get_domain_from_id +from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically from synapse.util.stringutils import non_null_str_or_none @@ -72,24 +70,6 @@ event_counter = Counter( ) -class PartialStateConflictError(SynapseError): - """An internal error raised when attempting to persist an event with partial state - after the room containing the event has been un-partial stated. - - This error should be handled by recomputing the event context and trying again. - - This error has an HTTP status code so that it can be transported over replication. - It should not be exposed to clients. - """ - - def __init__(self) -> None: - super().__init__( - HTTPStatus.CONFLICT, - msg="Cannot persist partial state event in un-partial stated room", - errcode=Codes.UNKNOWN, - ) - - @attr.s(slots=True, auto_attribs=True) class DeltaState: """Deltas to use to update the `current_state_events` table. @@ -306,7 +286,7 @@ class PersistEventsStore: # The set of event_ids to return. This includes all soft-failed events # and their prev events. - existing_prevs = set() + existing_prevs: Set[str] = set() def _get_prevs_before_rejected_txn( txn: LoggingTransaction, batch: Collection[str] @@ -571,7 +551,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, Sequence[str]], + event_to_auth_chain: Dict[str, StrCollection], ) -> None: """Calculate the chain cover index for the given events. @@ -865,7 +845,7 @@ class PersistEventsStore: event_chain_id_gen: SequenceGenerator, event_to_room_id: Dict[str, str], event_to_types: Dict[str, Tuple[str, str]], - event_to_auth_chain: Dict[str, Sequence[str]], + event_to_auth_chain: Dict[str, StrCollection], events_to_calc_chain_id_for: Set[str], chain_map: Dict[str, Tuple[int, int]], ) -> Dict[str, Tuple[int, int]]: diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index b9d3c36d60..584536111d 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast import attr @@ -29,7 +29,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.types import Cursor -from synapse.types import JsonDict +from synapse.types import JsonDict, StrCollection if TYPE_CHECKING: from synapse.server import HomeServer @@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): self.event_chain_id_gen, # type: ignore[attr-defined] event_to_room_id, event_to_types, - cast(Dict[str, Sequence[str]], event_to_auth_chain), + cast(Dict[str, StrCollection], event_to_auth_chain), ) return _CalculateChainCover( diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index db9a24db5e..4b1061e6d7 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): return await self.db_pool.runInteraction("count_users", _count_users) @cached(num_args=0) - async def get_monthly_active_count_by_service(self) -> Dict[str, int]: + async def get_monthly_active_count_by_service(self) -> Mapping[str, int]: """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 29972d5204..dddf49c2d5 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -21,7 +21,9 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, + Sequence, Tuple, cast, ) @@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> Sequence[JsonDict]: """Get receipts for a single room for sending to clients. Args: @@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[JsonDict]: + ) -> Sequence[JsonDict]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: @@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Dict[str, List[JsonDict]]: + ) -> Dict[str, Sequence[JsonDict]]: if not room_ids: return {} @@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None - ) -> Dict[str, JsonDict]: + ) -> Mapping[str, JsonDict]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 31f0f2bd3d..9a55e17624 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -16,7 +16,7 @@ import logging import random import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast import attr @@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): ) @cached() - async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]: """Deprecated: use get_userinfo_by_id instead""" def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]: diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 0018d6f7ab..fa3266c081 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -22,6 +22,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: + ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore): return result is not None @cached() - async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]: + async def get_aggregation_groups_for_event( + self, event_id: str + ) -> Sequence[JsonDict]: raise NotImplementedError() @cachedList( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index ea6a5e2f34..694a5b802c 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -24,6 +24,7 @@ from typing import ( List, Mapping, Optional, + Sequence, Set, Tuple, Union, @@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return self._known_servers_count @cached(max_entries=100000, iterable=True) - async def get_users_in_room(self, room_id: str) -> List[str]: + async def get_users_in_room(self, room_id: str) -> Sequence[str]: """Returns a list of users in the room. Will return inaccurate results for rooms with partial state, since the state for @@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached() - def get_user_in_room_with_profile( - self, room_id: str, user_id: str - ) -> Dict[str, ProfileInfo]: + def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo: raise NotImplementedError() @cachedList( @@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000, iterable=True) async def get_users_in_room_with_profiles( self, room_id: str - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """Get a mapping from user ID to profile information for all users in a given room. The profile information comes directly from this room's `m.room.member` @@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) @cached(max_entries=100000) - async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: + async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]: """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: @@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached() async def get_invited_rooms_for_local_user( self, user_id: str - ) -> List[RoomsForUser]: + ) -> Sequence[RoomsForUser]: """Get all the rooms the *local* user is invited to. Args: @@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return results @cached(iterable=True) - async def get_local_users_in_room(self, room_id: str) -> List[str]: + async def get_local_users_in_room(self, room_id: str) -> Sequence[str]: """ Retrieves a list of the current roommembers who are local to the server. """ @@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user(user_id) - user_who_share_room = set() + user_who_share_room: Set[str] = set() for room_id in room_ids: user_ids = await self.get_users_in_room(room_id) user_who_share_room.update(user_ids) @@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]: """Get current hosts in room based on current state.""" # First we check if we already have `get_users_in_room` in the cache, as diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index 05da15074a..5dcb1fc0b5 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Collection, Dict, List, Tuple +from typing import Collection, Dict, List, Mapping, Tuple from unpaddedbase64 import encode_base64 @@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList class SignatureWorkerStore(EventsWorkerStore): @cached() - def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]: + def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]: # This is a dummy function to allow get_event_reference_hashes # to use its cache raise NotImplementedError() @@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore): ) async def get_event_reference_hashes( self, event_ids: Collection[str] - ) -> Dict[str, Dict[str, bytes]]: + ) -> Mapping[str, Mapping[str, bytes]]: """Get all hashes for given events. Args: diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index d5500cdd47..c149a9eacb 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import Any, Dict, Iterable, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.tcp.streams import AccountDataStream @@ -32,7 +32,9 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): @cached() - async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: + async def get_tags_for_user( + self, user_id: str + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for a user. @@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore): async def get_updated_tags( self, user_id: str, stream_id: int - ) -> Dict[str, Dict[str, JsonDict]]: + ) -> Mapping[str, Mapping[str, JsonDict]]: """Get all the tags for the rooms where the tags have changed since the given version diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 14ef5b040d..f6a6fd4079 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -16,9 +16,9 @@ import logging import re from typing import ( TYPE_CHECKING, - Dict, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) @cached() - async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]: + async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]: return await self.db_pool.simple_select_one( table="user_directory", keyvalues={"user_id": user_id}, diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 3acdb39da7..6c335a9315 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -23,7 +23,7 @@ from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingDatabaseConnection -from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine +from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION from synapse.storage.types import Cursor @@ -108,9 +108,14 @@ def prepare_database( # so we start one before running anything. This ensures that any upgrades # are either applied completely, or not at all. # - # (psycopg2 automatically starts a transaction as soon as we run any statements - # at all, so this is redundant but harmless there.) - cur.execute("BEGIN TRANSACTION") + # psycopg2 does not automatically start transactions when in autocommit mode. + # While it is technically harmless to nest transactions in postgres, doing so + # results in a warning in Postgres' logs per query. And we'd rather like to + # avoid doing that. + if isinstance(database_engine, Sqlite3Engine) or ( + isinstance(database_engine, PostgresEngine) and db_conn.autocommit + ): + cur.execute("BEGIN TRANSACTION") logger.info("%r: Checking existing schema version", databases) version_info = _get_or_create_schema_state(cur, database_engine) diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index f82d1cfc29..52e366c8ae 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -69,6 +69,8 @@ StateMap = Mapping[StateKey, T] MutableStateMap = MutableMapping[StateKey, T] # JSON types. These could be made stronger, but will do for now. +# A "simple" (canonical) JSON value. +SimpleJsonValue = Optional[Union[str, int, bool]] # A JSON-serialisable dict. JsonDict = Dict[str, Any] # A JSON-serialisable mapping; roughly speaking an immutable JSONDict. |