diff options
Diffstat (limited to 'synapse')
49 files changed, 1727 insertions, 762 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 829061c870..5f0f34119b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -198,6 +198,12 @@ class EventContentFields: # cf https://github.com/matrix-org/matrix-doc/pull/1772 ROOM_TYPE = "type" + # The creator of the room, as used in `m.room.create` events. + ROOM_CREATOR = "creator" + + # Used in m.room.guest_access events. + GUEST_ACCESS = "guest_access" + # Used on normal messages to indicate they were historically imported after the fact MSC2716_HISTORICAL = "org.matrix.msc2716.historical" # For "insertion" events to indicate what the next chunk ID should be in @@ -232,5 +238,11 @@ class HistoryVisibility: WORLD_READABLE = "world_readable" +class GuestAccess: + CAN_JOIN = "can_join" + # anything that is not "can_join" is considered "forbidden", but for completeness: + FORBIDDEN = "forbidden" + + class ReadReceiptEventFields: MSC2285_HIDDEN = "org.matrix.msc2285.hidden" diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 39e28aff9f..89bda00090 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import atexit import gc import logging import os @@ -36,6 +37,7 @@ from synapse.api.constants import MAX_PDU_SIZE from synapse.app import check_bind_error from synapse.app.phone_stats_home import start_phone_stats_home from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import ManholeConfig from synapse.crypto import context_factory from synapse.events.presence_router import load_legacy_presence_router from synapse.events.spamcheck import load_legacy_spam_checkers @@ -229,7 +231,12 @@ def listen_metrics(bind_addresses, port): start_http_server(port, addr=host, registry=RegistryProxy) -def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: dict): +def listen_manhole( + bind_addresses: Iterable[str], + port: int, + manhole_settings: ManholeConfig, + manhole_globals: dict, +): # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so # suppress the warning for now. @@ -244,7 +251,7 @@ def listen_manhole(bind_addresses: Iterable[str], port: int, manhole_globals: di listen_tcp( bind_addresses, port, - manhole(username="matrix", password="rabbithole", globals=manhole_globals), + manhole(settings=manhole_settings, globals=manhole_globals), ) @@ -403,6 +410,12 @@ async def start(hs: "HomeServer"): gc.collect() gc.freeze() + # Speed up shutdowns by freezing all allocated objects. This moves everything + # into the permanent generation and excludes them from the final GC. + # Unfortunately only works on Python 3.7 + if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7): + atexit.register(gc.freeze) + def setup_sentry(hs): """Enable sentry integration, if enabled in configuration diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 9b71dd75e6..2eb8d5a79c 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -395,7 +395,10 @@ class GenericWorkerServer(HomeServer): self._listen_http(listener) elif listener.type == "manhole": _base.listen_manhole( - listener.bind_addresses, listener.port, manhole_globals={"hs": self} + listener.bind_addresses, + listener.port, + manhole_settings=self.config.server.manhole_settings, + manhole_globals={"hs": self}, ) elif listener.type == "metrics": if not self.config.enable_metrics: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 7dae163c1a..708db86f5d 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -291,7 +291,10 @@ class SynapseHomeServer(HomeServer): ) elif listener.type == "manhole": _base.listen_manhole( - listener.bind_addresses, listener.port, manhole_globals={"hs": self} + listener.bind_addresses, + listener.port, + manhole_settings=self.config.server.manhole_settings, + manhole_globals={"hs": self}, ) elif listener.type == "replication": services = listen_tcp( diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 1f42a51857..442f1b9ac0 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -30,6 +30,7 @@ from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig from .modules import ModulesConfig +from .oembed import OembedConfig from .oidc import OIDCConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig @@ -65,6 +66,7 @@ class HomeServerConfig(RootConfig): LoggingConfig, RatelimitConfig, ContentRepositoryConfig, + OembedConfig, CaptchaConfig, VoipConfig, RegistrationConfig, diff --git a/synapse/config/oembed.py b/synapse/config/oembed.py new file mode 100644 index 0000000000..09267b5eef --- /dev/null +++ b/synapse/config/oembed.py @@ -0,0 +1,180 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import re +from typing import Any, Dict, Iterable, List, Pattern +from urllib import parse as urlparse + +import attr +import pkg_resources + +from synapse.types import JsonDict + +from ._base import Config, ConfigError +from ._util import validate_config + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class OEmbedEndpointConfig: + # The API endpoint to fetch. + api_endpoint: str + # The patterns to match. + url_patterns: List[Pattern] + + +class OembedConfig(Config): + """oEmbed Configuration""" + + section = "oembed" + + def read_config(self, config, **kwargs): + oembed_config: Dict[str, Any] = config.get("oembed") or {} + + # A list of patterns which will be used. + self.oembed_patterns: List[OEmbedEndpointConfig] = list( + self._parse_and_validate_providers(oembed_config) + ) + + def _parse_and_validate_providers( + self, oembed_config: dict + ) -> Iterable[OEmbedEndpointConfig]: + """Extract and parse the oEmbed providers from the given JSON file. + + Returns a generator which yields the OidcProviderConfig objects + """ + # Whether to use the packaged providers.json file. + if not oembed_config.get("disable_default_providers") or False: + providers = json.load( + pkg_resources.resource_stream("synapse", "res/providers.json") + ) + yield from self._parse_and_validate_provider( + providers, config_path=("oembed",) + ) + + # The JSON files which includes additional provider information. + for i, file in enumerate(oembed_config.get("additional_providers") or []): + # TODO Error checking. + with open(file) as f: + providers = json.load(f) + + yield from self._parse_and_validate_provider( + providers, + config_path=( + "oembed", + "additional_providers", + f"<item {i}>", + ), + ) + + def _parse_and_validate_provider( + self, providers: List[JsonDict], config_path: Iterable[str] + ) -> Iterable[OEmbedEndpointConfig]: + # Ensure it is the proper form. + validate_config( + _OEMBED_PROVIDER_SCHEMA, + providers, + config_path=config_path, + ) + + # Parse it and yield each result. + for provider in providers: + # Each provider might have multiple API endpoints, each which + # might have multiple patterns to match. + for endpoint in provider["endpoints"]: + api_endpoint = endpoint["url"] + patterns = [ + self._glob_to_pattern(glob, config_path) + for glob in endpoint["schemes"] + ] + yield OEmbedEndpointConfig(api_endpoint, patterns) + + def _glob_to_pattern(self, glob: str, config_path: Iterable[str]) -> Pattern: + """ + Convert the glob into a sane regular expression to match against. The + rules followed will be slightly different for the domain portion vs. + the rest. + + 1. The scheme must be one of HTTP / HTTPS (and have no globs). + 2. The domain can have globs, but we limit it to characters that can + reasonably be a domain part. + TODO: This does not attempt to handle Unicode domain names. + TODO: The domain should not allow wildcard TLDs. + 3. Other parts allow a glob to be any one, or more, characters. + """ + results = urlparse.urlparse(glob) + + # Ensure the scheme does not have wildcards (and is a sane scheme). + if results.scheme not in {"http", "https"}: + raise ConfigError(f"Insecure oEmbed scheme: {results.scheme}", config_path) + + pattern = urlparse.urlunparse( + [ + results.scheme, + re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"), + ] + + [re.escape(part).replace("\\*", ".+") for part in results[2:]] + ) + return re.compile(pattern) + + def generate_config_section(self, **kwargs): + return """\ + # oEmbed allows for easier embedding content from a website. It can be + # used for generating URLs previews of services which support it. + # + oembed: + # A default list of oEmbed providers is included with Synapse. + # + # Uncomment the following to disable using these default oEmbed URLs. + # Defaults to 'false'. + # + #disable_default_providers: true + + # Additional files with oEmbed configuration (each should be in the + # form of providers.json). + # + # By default, this list is empty (so only the default providers.json + # is used). + # + #additional_providers: + # - oembed/my_providers.json + """ + + +_OEMBED_PROVIDER_SCHEMA = { + "type": "array", + "items": { + "type": "object", + "properties": { + "provider_name": {"type": "string"}, + "provider_url": {"type": "string"}, + "endpoints": { + "type": "array", + "items": { + "type": "object", + "properties": { + "schemes": { + "type": "array", + "items": {"type": "string"}, + }, + "url": {"type": "string"}, + "formats": {"type": "array", "items": {"type": "string"}}, + "discovery": {"type": "boolean"}, + }, + "required": ["schemes", "url"], + }, + }, + }, + "required": ["provider_name", "provider_url", "endpoints"], + }, +} diff --git a/synapse/config/server.py b/synapse/config/server.py index d2c900f50c..7b9109a592 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -25,11 +25,14 @@ import attr import yaml from netaddr import AddrFormatError, IPNetwork, IPSet +from twisted.conch.ssh.keys import Key + from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError +from ._util import validate_config logger = logging.Logger(__name__) @@ -216,6 +219,16 @@ class ListenerConfig: http_options = attr.ib(type=Optional[HttpListenerConfig], default=None) +@attr.s(frozen=True) +class ManholeConfig: + """Object describing the configuration of the manhole""" + + username = attr.ib(type=str, validator=attr.validators.instance_of(str)) + password = attr.ib(type=str, validator=attr.validators.instance_of(str)) + priv_key = attr.ib(type=Optional[Key]) + pub_key = attr.ib(type=Optional[Key]) + + class ServerConfig(Config): section = "server" @@ -649,6 +662,41 @@ class ServerConfig(Config): ) ) + manhole_settings = config.get("manhole_settings") or {} + validate_config( + _MANHOLE_SETTINGS_SCHEMA, manhole_settings, ("manhole_settings",) + ) + + manhole_username = manhole_settings.get("username", "matrix") + manhole_password = manhole_settings.get("password", "rabbithole") + manhole_priv_key_path = manhole_settings.get("ssh_priv_key_path") + manhole_pub_key_path = manhole_settings.get("ssh_pub_key_path") + + manhole_priv_key = None + if manhole_priv_key_path is not None: + try: + manhole_priv_key = Key.fromFile(manhole_priv_key_path) + except Exception as e: + raise ConfigError( + f"Failed to read manhole private key file {manhole_priv_key_path}" + ) from e + + manhole_pub_key = None + if manhole_pub_key_path is not None: + try: + manhole_pub_key = Key.fromFile(manhole_pub_key_path) + except Exception as e: + raise ConfigError( + f"Failed to read manhole public key file {manhole_pub_key_path}" + ) from e + + self.manhole_settings = ManholeConfig( + username=manhole_username, + password=manhole_password, + priv_key=manhole_priv_key, + pub_key=manhole_pub_key, + ) + metrics_port = config.get("metrics_port") if metrics_port: logger.warning(METRICS_PORT_WARNING) @@ -715,7 +763,7 @@ class ServerConfig(Config): if not isinstance(templates_config, dict): raise ConfigError("The 'templates' section must be a dictionary") - self.custom_template_directory = templates_config.get( + self.custom_template_directory: Optional[str] = templates_config.get( "custom_template_directory" ) if self.custom_template_directory is not None and not isinstance( @@ -727,7 +775,13 @@ class ServerConfig(Config): return any(listener.tls for listener in self.listeners) def generate_config_section( - self, server_name, data_dir_path, open_private_ports, listeners, **kwargs + self, + server_name, + data_dir_path, + open_private_ports, + listeners, + config_dir_path, + **kwargs, ): ip_range_blacklist = "\n".join( " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST @@ -1068,6 +1122,24 @@ class ServerConfig(Config): # bind_addresses: ['::1', '127.0.0.1'] # type: manhole + # Connection settings for the manhole + # + manhole_settings: + # The username for the manhole. This defaults to 'matrix'. + # + #username: manhole + + # The password for the manhole. This defaults to 'rabbithole'. + # + #password: mypassword + + # The private and public SSH key pair used to encrypt the manhole traffic. + # If these are left unset, then hardcoded and non-secret keys are used, + # which could allow traffic to be intercepted if sent over a public network. + # + #ssh_priv_key_path: %(config_dir_path)s/id_rsa + #ssh_pub_key_path: %(config_dir_path)s/id_rsa.pub + # Forward extremities can build up in a room due to networking delays between # homeservers. Once this happens in a large room, calculation of the state of # that room can become quite expensive. To mitigate this, once the number of @@ -1436,3 +1508,14 @@ def _warn_if_webclient_configured(listeners: Iterable[ListenerConfig]) -> None: if name == "webclient": logger.warning(NO_MORE_WEB_CLIENT_WARNING) return + + +_MANHOLE_SETTINGS_SCHEMA = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "password": {"type": "string"}, + "ssh_priv_key_path": {"type": "string"}, + "ssh_pub_key_path": {"type": "string"}, + }, +} diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c3a0c10499..b63a1afe93 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -216,21 +216,18 @@ def check( def _check_size_limits(event: EventBase) -> None: - def too_big(field): - raise EventSizeError("%s too large" % (field,)) - if len(event.user_id) > 255: - too_big("user_id") + raise EventSizeError("'user_id' too large") if len(event.room_id) > 255: - too_big("room_id") + raise EventSizeError("'room_id' too large") if event.is_state() and len(event.state_key) > 255: - too_big("state_key") + raise EventSizeError("'state_key' too large") if len(event.type) > 255: - too_big("type") + raise EventSizeError("'type' too large") if len(event.event_id) > 255: - too_big("event_id") + raise EventSizeError("'event_id' too large") if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE: - too_big("event") + raise EventSizeError("event too large") def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool: diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 6a05a65305..955cfa2207 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -15,10 +15,7 @@ import logging from typing import TYPE_CHECKING, Optional -import synapse.types -from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter -from synapse.types import UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -115,68 +112,3 @@ class BaseHandler: burst_count=burst_count, update=update, ) - - async def maybe_kick_guest_users(self, event, context=None): - # Technically this function invalidates current_state by changing it. - # Hopefully this isn't that important to the caller. - if event.type == EventTypes.GuestAccess: - guest_access = event.content.get("guest_access", "forbidden") - if guest_access != "can_join": - if context: - current_state_ids = await context.get_current_state_ids() - current_state_dict = await self.store.get_events( - list(current_state_ids.values()) - ) - current_state = list(current_state_dict.values()) - else: - current_state_map = await self.state_handler.get_current_state( - event.room_id - ) - current_state = list(current_state_map.values()) - - logger.info("maybe_kick_guest_users %r", current_state) - await self.kick_guest_users(current_state) - - async def kick_guest_users(self, current_state): - for member_event in current_state: - try: - if member_event.type != EventTypes.Member: - continue - - target_user = UserID.from_string(member_event.state_key) - if not self.hs.is_mine(target_user): - continue - - if member_event.content["membership"] not in { - Membership.JOIN, - Membership.INVITE, - }: - continue - - if ( - "kind" not in member_event.content - or member_event.content["kind"] != "guest" - ): - continue - - # We make the user choose to leave, rather than have the - # event-sender kick them. This is partially because we don't - # need to worry about power levels, and partially because guest - # users are a concept which doesn't hugely work over federation, - # and having homeservers have their own users leave keeps more - # of that decision-making and control local to the guest-having - # homeserver. - requester = synapse.types.create_requester( - target_user, is_guest=True, authenticated_entity=self.server_name - ) - handler = self.hs.get_room_member_handler() - await handler.update_membership( - requester, - target_user, - member_event.room_id, - "leave", - ratelimit=False, - require_consent=False, - ) - except Exception as e: - logger.exception("Error kicking guest user: %s" % (e,)) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index daf1d3bfb3..77df9185f6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -507,6 +507,7 @@ class FederationHandler(BaseHandler): await self.store.upsert_room_on_join( room_id=room_id, room_version=room_version_obj, + auth_events=auth_chain, ) max_stream_id = await self._persist_auth_tree( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 9f055f00cf..69f8287b2b 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -36,6 +36,7 @@ from synapse import event_auth from synapse.api.constants import ( EventContentFields, EventTypes, + GuestAccess, Membership, RejectedReason, RoomEncryptionAlgorithms, @@ -53,7 +54,6 @@ from synapse.event_auth import auth_types_for_event from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers._base import BaseHandler from synapse.logging.context import ( make_deferred_yieldable, nested_logging_context, @@ -116,7 +116,7 @@ class _NewEventInfo: claimed_auth_event_map: StateMap[EventBase] -class FederationEventHandler(BaseHandler): +class FederationEventHandler: """Handles events that originated from federation. Responsible for handing incoming events and passing them on to the rest @@ -124,26 +124,28 @@ class FederationEventHandler(BaseHandler): """ def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self._store = hs.get_datastore() + self._storage = hs.get_storage() + self._state_store = self._storage.state - self.store = hs.get_datastore() - self.storage = hs.get_storage() - self.state_store = self.storage.state - - self.state_handler = hs.get_state_handler() - self.event_creation_handler = hs.get_event_creation_handler() + self._state_handler = hs.get_state_handler() + self._event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() - self.action_generator = hs.get_action_generator() + self._action_generator = hs.get_action_generator() self._state_resolution_handler = hs.get_state_resolution_handler() + # avoid a circular dependency by deferring execution here + self._get_room_member_handler = hs.get_room_member_handler - self.federation_client = hs.get_federation_client() - self.third_party_event_rules = hs.get_third_party_event_rules() + self._federation_client = hs.get_federation_client() + self._third_party_event_rules = hs.get_third_party_event_rules() + self._notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id + self._is_mine_id = hs.is_mine_id + self._server_name = hs.hostname self._instance_name = hs.get_instance_name() - self.config = hs.config + self._config = hs.config self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) @@ -175,7 +177,7 @@ class FederationEventHandler(BaseHandler): event_id = pdu.event_id # We reprocess pdus when we have seen them only as outliers - existing = await self.store.get_event( + existing = await self._store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -221,7 +223,7 @@ class FederationEventHandler(BaseHandler): # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. is_in_room = await self._event_auth_handler.check_host_in_room( - room_id, self.server_name + room_id, self._server_name ) if not is_in_room: logger.info( @@ -238,7 +240,7 @@ class FederationEventHandler(BaseHandler): # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if missing_prevs: @@ -272,7 +274,7 @@ class FederationEventHandler(BaseHandler): # Update the set of things we've seen after trying to # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if not missing_prevs: @@ -361,7 +363,7 @@ class FederationEventHandler(BaseHandler): # the room, so we send it on their behalf. event.internal_metadata.send_on_behalf_of = origin - context = await self.state_handler.compute_event_context(event) + context = await self._state_handler.compute_event_context(event) context = await self._check_event_auth(origin, event, context) if context.rejected: raise SynapseError( @@ -375,7 +377,7 @@ class FederationEventHandler(BaseHandler): # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed, _ = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self._third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -404,7 +406,7 @@ class FederationEventHandler(BaseHandler): prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) prev_member_event = None if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) + prev_member_event = await self._store.get_event(prev_member_event_id) # Check if the member should be allowed access via membership in a space. await self._event_auth_handler.check_restricted_join_rules( @@ -434,10 +436,10 @@ class FederationEventHandler(BaseHandler): server from invalid events (there is probably no point in trying to re-fetch invalid events from every other HS in the room.) """ - if dest == self.server_name: + if dest == self._server_name: raise SynapseError(400, "Can't backfill from self.") - events = await self.federation_client.backfill( + events = await self._federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -469,12 +471,12 @@ class FederationEventHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) if not prevs - seen: return - latest_list = await self.store.get_latest_event_ids_in_room(room_id) + latest_list = await self._store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -536,7 +538,7 @@ class FederationEventHandler(BaseHandler): # All that said: Let's try increasing the timeout to 60s and see what happens. try: - missing_events = await self.federation_client.get_missing_events( + missing_events = await self._federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -609,7 +611,7 @@ class FederationEventHandler(BaseHandler): event_id = event.event_id - existing = await self.store.get_event( + existing = await self._store.get_event( event_id, allow_none=True, allow_rejected=True ) if existing: @@ -674,7 +676,7 @@ class FederationEventHandler(BaseHandler): event_id = event.event_id prevs = set(event.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if not missing_prevs: @@ -691,7 +693,7 @@ class FederationEventHandler(BaseHandler): event_map = {event_id: event} try: # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) + ours = await self._state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps: List[StateMap[str]] = list(ours.values()) @@ -720,13 +722,13 @@ class FederationEventHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - room_version = await self.store.get_room_version_id(room_id) + room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, event_map, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore(self._store), ) # We need to give _process_received_pdu the actual state events @@ -734,7 +736,7 @@ class FederationEventHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = await self.store.get_events( + evs = await self._store.get_events( list(state_map.values()), get_prev_content=False, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -774,7 +776,7 @@ class FederationEventHandler(BaseHandler): ( state_event_ids, auth_event_ids, - ) = await self.federation_client.get_room_state_ids( + ) = await self._federation_client.get_room_state_ids( destination, room_id, event_id=event_id ) @@ -788,7 +790,7 @@ class FederationEventHandler(BaseHandler): desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self.store.get_events( + fetched_events = await self._store.get_events( desired_events, allow_rejected=True ) @@ -809,7 +811,7 @@ class FederationEventHandler(BaseHandler): missing_auth_events = set(auth_event_ids) - fetched_events.keys() missing_auth_events.difference_update( - await self.store.have_seen_events(room_id, missing_auth_events) + await self._store.have_seen_events(room_id, missing_auth_events) ) logger.debug("We are also missing %i auth events", len(missing_auth_events)) @@ -822,7 +824,7 @@ class FederationEventHandler(BaseHandler): # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - await self.store.get_events(missing_desired_events, allow_rejected=True) + await self._store.get_events(missing_desired_events, allow_rejected=True) ) # check for events which were in the wrong room. @@ -901,7 +903,7 @@ class FederationEventHandler(BaseHandler): logger.debug("Processing event: %s", event) try: - context = await self.state_handler.compute_event_context( + context = await self._state_handler.compute_event_context( event, old_state=state ) await self._auth_and_persist_event( @@ -919,7 +921,7 @@ class FederationEventHandler(BaseHandler): device_id = event.content.get("device_id") sender_key = event.content.get("sender_key") - cached_devices = await self.store.get_cached_devices_for_user(event.sender) + cached_devices = await self._store.get_cached_devices_for_user(event.sender) resync = False # Whether we should resync device lists. @@ -995,10 +997,10 @@ class FederationEventHandler(BaseHandler): """ try: - await self.store.mark_remote_user_device_cache_as_stale(sender) + await self._store.mark_remote_user_device_cache_as_stale(sender) # Immediately attempt a resync in the background - if self.config.worker_app: + if self._config.worker_app: await self._user_device_resync(user_id=sender) else: await self._device_list_updater.user_device_resync(sender) @@ -1023,9 +1025,15 @@ class FederationEventHandler(BaseHandler): return # Skip processing a marker event if the room version doesn't - # support it. - room_version = await self.store.get_room_version(marker_event.room_id) - if not room_version.msc2716_historical: + # support it or the event is not from the room creator. + room_version = await self._store.get_room_version(marker_event.room_id) + create_event = await self._store.get_create_event_for_room(marker_event.room_id) + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + if ( + not room_version.msc2716_historical + or not self._config.experimental.msc2716_enabled + or marker_event.sender != room_creator + ): return logger.debug("_handle_marker_event: received %s", marker_event) @@ -1048,7 +1056,7 @@ class FederationEventHandler(BaseHandler): [insertion_event_id], ) - insertion_event = await self.store.get_event( + insertion_event = await self._store.get_event( insertion_event_id, allow_none=True ) if insertion_event is None: @@ -1066,7 +1074,7 @@ class FederationEventHandler(BaseHandler): marker_event, ) - await self.store.insert_insertion_extremity( + await self._store.insert_insertion_extremity( insertion_event_id, marker_event.room_id ) @@ -1088,14 +1096,14 @@ class FederationEventHandler(BaseHandler): Logs a warning if we can't find the given event. """ - room_version = await self.store.get_room_version(room_id) + room_version = await self._store.get_room_version(room_id) event_map: Dict[str, EventBase] = {} async def get_event(event_id: str): with nested_logging_context(event_id): try: - event = await self.federation_client.get_pdu( + event = await self._federation_client.get_pdu( [destination], event_id, room_version, @@ -1131,7 +1139,7 @@ class FederationEventHandler(BaseHandler): for aid in event.auth_event_ids() if aid not in event_map ] - persisted_events = await self.store.get_events( + persisted_events = await self._store.get_events( auth_events, allow_rejected=True, ) @@ -1175,7 +1183,7 @@ class FederationEventHandler(BaseHandler): async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context(event) + res = await self._state_handler.compute_event_context(event) res = await self._check_event_auth( origin, event, @@ -1278,7 +1286,7 @@ class FederationEventHandler(BaseHandler): Returns: The updated context object. """ - room_version = await self.store.get_room_version_id(event.room_id) + room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if claimed_auth_event_map: @@ -1291,7 +1299,7 @@ class FederationEventHandler(BaseHandler): auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events_x = await self.store.get_events(auth_events_ids) + auth_events_x = await self._store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} try: @@ -1321,19 +1329,29 @@ class FederationEventHandler(BaseHandler): if not context.rejected: await self._check_for_soft_fail(event, state, backfilled, origin=origin) - - if event.type == EventTypes.GuestAccess and not context.rejected: - await self.maybe_kick_guest_users(event) + await self._maybe_kick_guest_users(event) # If we are going to send this event over federation we precaclculate # the joined hosts. if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event( + await self._event_creation_handler.cache_joined_hosts_for_event( event, context ) return context + async def _maybe_kick_guest_users(self, event: EventBase) -> None: + if event.type != EventTypes.GuestAccess: + return + + guest_access = event.content.get(EventContentFields.GUEST_ACCESS) + if guest_access == GuestAccess.CAN_JOIN: + return + + current_state_map = await self._state_handler.get_current_state(event.room_id) + current_state = list(current_state_map.values()) + await self._get_room_member_handler().kick_guest_users(current_state) + async def _check_for_soft_fail( self, event: EventBase, @@ -1356,7 +1374,7 @@ class FederationEventHandler(BaseHandler): if backfilled or event.internal_metadata.is_outlier(): return - extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) @@ -1365,7 +1383,7 @@ class FederationEventHandler(BaseHandler): # state at the event, so no point rechecking auth for soft fail. return - room_version = await self.store.get_room_version_id(event.room_id) + room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". @@ -1382,19 +1400,19 @@ class FederationEventHandler(BaseHandler): # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self.state_store.get_state_groups( + state_sets_d = await self._state_store.get_state_groups( event.room_id, extrem_ids ) state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) state_sets.append(state) - current_states = await self.state_handler.resolve_events( + current_states = await self._state_handler.resolve_events( room_version, state_sets, event ) current_state_ids: StateMap[str] = { k: e.event_id for k, e in current_states.items() } else: - current_state_ids = await self.state_handler.get_current_state_ids( + current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids ) @@ -1410,7 +1428,7 @@ class FederationEventHandler(BaseHandler): e for k, e in current_state_ids.items() if k in auth_types ] - auth_events_map = await self.store.get_events(current_state_ids_list) + auth_events_map = await self._store.get_events(current_state_ids_list) current_auth_events = { (e.type, e.state_key): e for e in auth_events_map.values() } @@ -1481,7 +1499,9 @@ class FederationEventHandler(BaseHandler): # # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - have_events = await self.store.have_seen_events(event.room_id, missing_auth) + have_events = await self._store.have_seen_events( + event.room_id, missing_auth + ) logger.debug("Events %s are in the store", have_events) missing_auth.difference_update(have_events) @@ -1490,7 +1510,7 @@ class FederationEventHandler(BaseHandler): logger.info("auth_events contains unknown events: %s", missing_auth) try: try: - remote_auth_chain = await self.federation_client.get_event_auth( + remote_auth_chain = await self._federation_client.get_event_auth( origin, event.room_id, event.event_id ) except RequestSendFailed as e1: @@ -1499,7 +1519,7 @@ class FederationEventHandler(BaseHandler): logger.info("Failed to get event auth from remote: %s", e1) return context, auth_events - seen_remotes = await self.store.have_seen_events( + seen_remotes = await self._store.have_seen_events( event.room_id, [e.event_id for e in remote_auth_chain] ) @@ -1525,7 +1545,7 @@ class FederationEventHandler(BaseHandler): e.event_id, ) missing_auth_event_context = ( - await self.state_handler.compute_event_context(e) + await self._state_handler.compute_event_context(e) ) await self._auth_and_persist_event( origin, @@ -1566,7 +1586,7 @@ class FederationEventHandler(BaseHandler): # XXX: currently this checks for redactions but I'm not convinced that is # necessary? - different_events = await self.store.get_events_as_list(different_auth) + different_events = await self._store.get_events_as_list(different_auth) for d in different_events: if d.room_id != event.room_id: @@ -1592,8 +1612,8 @@ class FederationEventHandler(BaseHandler): remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) remote_state = remote_auth_events.values() - room_version = await self.store.get_room_version_id(event.room_id) - new_state = await self.state_handler.resolve_events( + room_version = await self._store.get_room_version_id(event.room_id) + new_state = await self._state_handler.resolve_events( room_version, (local_state, remote_state), event ) @@ -1651,7 +1671,7 @@ class FederationEventHandler(BaseHandler): # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = await self.state_store.store_state_group( + state_group = await self._state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -1683,9 +1703,9 @@ class FederationEventHandler(BaseHandler): not event.internal_metadata.is_outlier() and not backfilled and not context.rejected - and (await self.store.get_min_depth(event.room_id)) <= event.depth + and (await self._store.get_min_depth(event.room_id)) <= event.depth ): - await self.action_generator.handle_push_actions_for_event( + await self._action_generator.handle_push_actions_for_event( event, context ) @@ -1694,7 +1714,7 @@ class FederationEventHandler(BaseHandler): ) except Exception: run_in_background( - self.store.remove_push_actions_from_staging, event.event_id + self._store.remove_push_actions_from_staging, event.event_id ) raise @@ -1719,27 +1739,27 @@ class FederationEventHandler(BaseHandler): The stream ID after which all events have been persisted. """ if not event_and_contexts: - return self.store.get_current_events_token() + return self._store.get_current_events_token() - instance = self.config.worker.events_shard_config.get_instance(room_id) + instance = self._config.worker.events_shard_config.get_instance(room_id) if instance != self._instance_name: # Limit the number of events sent over replication. We choose 200 # here as that is what we default to in `max_request_body_size(..)` for batch in batch_iter(event_and_contexts, 200): result = await self._send_events( instance_name=instance, - store=self.store, + store=self._store, room_id=room_id, event_and_contexts=batch, backfilled=backfilled, ) return result["max_stream_id"] else: - assert self.storage.persistence + assert self._storage.persistence # Note that this returns the events that were persisted, which may not be # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self.storage.persistence.persist_events( + events, max_stream_token = await self._storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -1773,7 +1793,7 @@ class FederationEventHandler(BaseHandler): # users if event.internal_metadata.is_outlier(): if event.membership != Membership.INVITE: - if not self.is_mine_id(target_user_id): + if not self._is_mine_id(target_user_id): return target_user = UserID.from_string(target_user_id) @@ -1787,7 +1807,7 @@ class FederationEventHandler(BaseHandler): event_pos = PersistedEventPosition( self._instance_name, event.internal_metadata.stream_ordering ) - self.notifier.on_new_room_event( + self._notifier.on_new_room_event( event, event_pos, max_stream_token, extra_users=extra_users ) @@ -1822,4 +1842,4 @@ class FederationEventHandler(BaseHandler): raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def get_min_depth_for_context(self, context: str) -> int: - return await self.store.get_min_depth(context) + return await self._store.get_min_depth(context) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 101a29c6d3..bf0fef1510 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -27,6 +27,7 @@ from synapse import event_auth from synapse.api.constants import ( EventContentFields, EventTypes, + GuestAccess, Membership, RelationTypes, UserTypes, @@ -426,7 +427,7 @@ class EventCreationHandler: self.send_event = ReplicationSendEventRestServlet.make_client(hs) - # This is only used to get at ratelimit function, and maybe_kick_guest_users + # This is only used to get at ratelimit function self.base_handler = BaseHandler(hs) # We arbitrarily limit concurrent event creation for a room to 5. @@ -1306,7 +1307,7 @@ class EventCreationHandler: requester, is_admin_redaction=is_admin_redaction ) - await self.base_handler.maybe_kick_guest_users(event, context) + await self._maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Validate a newly added alias or newly added alt_aliases. @@ -1393,6 +1394,9 @@ class EventCreationHandler: allow_none=True, ) + room_version = await self.store.get_room_version_id(event.room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + # we can make some additional checks now if we have the original event. if original_event: if original_event.type == EventTypes.Create: @@ -1404,6 +1408,28 @@ class EventCreationHandler: if original_event.type == EventTypes.ServerACL: raise AuthError(403, "Redacting server ACL events is not permitted") + # Add a little safety stop-gap to prevent people from trying to + # redact MSC2716 related events when they're in a room version + # which does not support it yet. We allow people to use MSC2716 + # events in existing room versions but only from the room + # creator since it does not require any changes to the auth + # rules and in effect, the redaction algorithm . In the + # supported room version, we add the `historical` power level to + # auth the MSC2716 related events and adjust the redaction + # algorthim to keep the `historical` field around (redacting an + # event should only strip fields which don't affect the + # structural protocol level). + is_msc2716_event = ( + original_event.type == EventTypes.MSC2716_INSERTION + or original_event.type == EventTypes.MSC2716_CHUNK + or original_event.type == EventTypes.MSC2716_MARKER + ) + if not room_version_obj.msc2716_historical and is_msc2716_event: + raise AuthError( + 403, + "Redacting MSC2716 events is not supported in this room version", + ) + prev_state_ids = await context.get_prev_state_ids() auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True @@ -1411,9 +1437,6 @@ class EventCreationHandler: auth_events_map = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()} - room_version = await self.store.get_room_version_id(event.room_id) - room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if event_auth.check_redaction( room_version_obj, event, auth_events=auth_events ): @@ -1471,6 +1494,28 @@ class EventCreationHandler: return event + async def _maybe_kick_guest_users( + self, event: EventBase, context: EventContext + ) -> None: + if event.type != EventTypes.GuestAccess: + return + + guest_access = event.content.get(EventContentFields.GUEST_ACCESS) + if guest_access == GuestAccess.CAN_JOIN: + return + + current_state_ids = await context.get_current_state_ids() + + # since this is a client-generated event, it cannot be an outlier and we must + # therefore have the state ids. + assert current_state_ids is not None + current_state_dict = await self.store.get_events( + list(current_state_ids.values()) + ) + current_state = list(current_state_dict.values()) + logger.info("maybe_kick_guest_users %r", current_state) + await self.hs.get_room_member_handler().kick_guest_users(current_state) + async def _bump_active_time(self, user: UserID) -> None: try: presence = self.hs.get_presence_handler() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b33fe09f77..0235fd09b4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -25,7 +25,9 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple from synapse.api.constants import ( + EventContentFields, EventTypes, + GuestAccess, HistoryVisibility, JoinRules, Membership, @@ -909,7 +911,12 @@ class RoomCreationHandler(BaseHandler): ) return last_stream_id - config = self._presets_dict[preset_config] + try: + config = self._presets_dict[preset_config] + except KeyError: + raise SynapseError( + 400, f"'{preset_config}' is not a valid preset", errcode=Codes.BAD_JSON + ) creation_content.update({"creator": creator_id}) await send(etype=EventTypes.Create, content=creation_content) @@ -988,7 +995,8 @@ class RoomCreationHandler(BaseHandler): if config["guest_can_join"]: if (EventTypes.GuestAccess, "") not in initial_state: last_sent_stream_id = await send( - etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} + etype=EventTypes.GuestAccess, + content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, ) for (etype, state_key), content in initial_state.items(): diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 6d433fad41..92bb75c848 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -19,7 +19,13 @@ from typing import TYPE_CHECKING, Optional, Tuple import msgpack from unpaddedbase64 import decode_base64, encode_base64 -from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules +from synapse.api.constants import ( + EventContentFields, + EventTypes, + GuestAccess, + HistoryVisibility, + JoinRules, +) from synapse.api.errors import ( Codes, HttpResponseException, @@ -336,8 +342,8 @@ class RoomListHandler(BaseHandler): guest_event = current_state.get((EventTypes.GuestAccess, "")) guest = None if guest_event: - guest = guest_event.content.get("guest_access", None) - result["guest_can_join"] = guest == "can_join" + guest = guest_event.content.get(EventContentFields.GUEST_ACCESS) + result["guest_can_join"] = guest == GuestAccess.CAN_JOIN avatar_event = current_state.get(("m.room.avatar", "")) if avatar_event: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 401b84aad1..4390201641 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -23,6 +23,7 @@ from synapse.api.constants import ( AccountDataTypes, EventContentFields, EventTypes, + GuestAccess, Membership, ) from synapse.api.errors import ( @@ -44,6 +45,7 @@ from synapse.types import ( RoomID, StateMap, UserID, + create_requester, get_domain_from_id, ) from synapse.util.async_helpers import Linearizer @@ -70,6 +72,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config + self._server_name = hs.hostname self.federation_handler = hs.get_federation_handler() self.directory_handler = hs.get_directory_handler() @@ -115,9 +118,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, ) - # This is only used to get at ratelimit function, and - # maybe_kick_guest_users. It's fine there are multiple of these as - # it doesn't store state. + # This is only used to get at the ratelimit function. It's fine there are + # multiple of these as it doesn't store state. self.base_handler = BaseHandler(hs) @abc.abstractmethod @@ -1095,10 +1097,62 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): return bool( guest_access and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" + and guest_access.content.get(EventContentFields.GUEST_ACCESS) + == GuestAccess.CAN_JOIN ) + async def kick_guest_users(self, current_state: Iterable[EventBase]) -> None: + """Kick any local guest users from the room. + + This is called when the room state changes from guests allowed to not-allowed. + + Params: + current_state: the current state of the room. We will iterate this to look + for guest users to kick. + """ + for member_event in current_state: + try: + if member_event.type != EventTypes.Member: + continue + + if not self.hs.is_mine_id(member_event.state_key): + continue + + if member_event.content["membership"] not in { + Membership.JOIN, + Membership.INVITE, + }: + continue + + if ( + "kind" not in member_event.content + or member_event.content["kind"] != "guest" + ): + continue + + # We make the user choose to leave, rather than have the + # event-sender kick them. This is partially because we don't + # need to worry about power levels, and partially because guest + # users are a concept which doesn't hugely work over federation, + # and having homeservers have their own users leave keeps more + # of that decision-making and control local to the guest-having + # homeserver. + target_user = UserID.from_string(member_event.state_key) + requester = create_requester( + target_user, is_guest=True, authenticated_entity=self._server_name + ) + handler = self.hs.get_room_member_handler() + await handler.update_membership( + requester, + target_user, + member_event.room_id, + "leave", + ratelimit=False, + require_consent=False, + ) + except Exception as e: + logger.exception("Error kicking guest user: %s" % (e,)) + async def lookup_room_alias( self, room_alias: RoomAlias ) -> Tuple[RoomID, List[str]]: @@ -1352,7 +1406,6 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor = hs.get_distributor() self.distributor.declare("user_left_room") - self._server_name = hs.hostname async def _is_remote_room_too_complex( self, room_id: str, remote_room_hosts: List[str] diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 906985c754..781da9e811 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -28,9 +28,15 @@ from synapse.api.constants import ( Membership, RoomTypes, ) -from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError +from synapse.api.errors import ( + AuthError, + Codes, + NotFoundError, + StoreError, + SynapseError, + UnsupportedRoomVersionError, +) from synapse.events import EventBase -from synapse.events.utils import format_event_for_client_v2 from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache @@ -82,7 +88,6 @@ class RoomSummaryHandler: _PAGINATION_SESSION_VALIDITY_PERIOD_MS = 5 * 60 * 1000 def __init__(self, hs: "HomeServer"): - self._clock = hs.get_clock() self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastore() self._event_serializer = hs.get_event_client_serializer() @@ -641,18 +646,18 @@ class RoomSummaryHandler: if max_children is None or max_children > MAX_ROOMS_PER_SPACE: max_children = MAX_ROOMS_PER_SPACE - now = self._clock.time_msec() - events_result: List[JsonDict] = [] - for edge_event in itertools.islice(child_events, max_children): - events_result.append( - await self._event_serializer.serialize_event( - edge_event, - time_now=now, - event_format=format_event_for_client_v2, - ) - ) - - return _RoomEntry(room_id, room_entry, events_result) + stripped_events: List[JsonDict] = [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "room_id": e.room_id, + "sender": e.sender, + "origin_server_ts": e.origin_server_ts, + } + for e in itertools.islice(child_events, max_children) + ] + return _RoomEntry(room_id, room_entry, stripped_events) async def _summarize_remote_room( self, @@ -814,7 +819,12 @@ class RoomSummaryHandler: logger.info("room %s is unknown, omitting from summary", room_id) return False - room_version = await self._store.get_room_version(room_id) + try: + room_version = await self._store.get_room_version(room_id) + except UnsupportedRoomVersionError: + # If a room with an unsupported room version is encountered, ignore + # it to avoid breaking the entire summary response. + return False # Include the room if it has join rules of public or knock. join_rules_event_id = state_ids.get((EventTypes.JoinRules, "")) @@ -1139,25 +1149,26 @@ def _is_suggested_child_event(edge_event: EventBase) -> bool: _INVALID_ORDER_CHARS_RE = re.compile(r"[^\x20-\x7E]") -def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], str]: +def _child_events_comparison_key( + child: EventBase, +) -> Tuple[bool, Optional[str], int, str]: """ Generate a value for comparing two child events for ordering. - The rules for ordering are supposed to be: + The rules for ordering are: 1. The 'order' key, if it is valid. - 2. The 'origin_server_ts' of the 'm.room.create' event. + 2. The 'origin_server_ts' of the 'm.space.child' event. 3. The 'room_id'. - But we skip step 2 since we may not have any state from the room. - Args: child: The event for generating a comparison key. Returns: The comparison key as a tuple of: False if the ordering is valid. - The ordering field. + The 'order' field or None if it is not given or invalid. + The 'origin_server_ts' field. The room ID. """ order = child.content.get("order") @@ -1168,4 +1179,4 @@ def _child_events_comparison_key(child: EventBase) -> Tuple[bool, Optional[str], order = None # Items without an order come last. - return (order is None, order, child.room_id) + return (order is None, order, child.origin_server_ts, child.room_id) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 3fd89af2a4..3a4c41c9ff 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple from typing_extensions import Counter as CounterType -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict @@ -273,7 +273,9 @@ class StatsHandler: elif typ == EventTypes.CanonicalAlias: room_state["canonical_alias"] = event_content.get("alias") elif typ == EventTypes.GuestAccess: - room_state["guest_access"] = event_content.get("guest_access") + room_state["guest_access"] = event_content.get( + EventContentFields.GUEST_ACCESS + ) for room_id, state in room_to_state_updates.items(): logger.debug("Updating room_stats_state for %s: %s", room_id, state) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 86c3c7f0df..e017b28cd2 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -505,10 +505,13 @@ class SyncHandler: else: limited = False + log_kv({"limited": limited}) + if potential_recents: recents = sync_config.filter_collection.filter_room_timeline( potential_recents ) + log_kv({"recents_after_sync_filtering": len(recents)}) # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to @@ -526,6 +529,7 @@ class SyncHandler: recents, always_include_ids=current_state_ids, ) + log_kv({"recents_after_visibility_filtering": len(recents)}) else: recents = [] @@ -566,10 +570,15 @@ class SyncHandler: events, end_key = await self.store.get_recent_events_for_room( room_id, limit=load_limit + 1, end_token=end_key ) + + log_kv({"loaded_recents": len(events)}) + loaded_recents = sync_config.filter_collection.filter_room_timeline( events ) + log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)}) + # We check if there are any state events, if there are then we pass # all current state events to the filter_events function. This is to # ensure that we always include current state in the timeline @@ -586,6 +595,9 @@ class SyncHandler: loaded_recents, always_include_ids=current_state_ids, ) + + log_kv({"loaded_recents_after_client_filtering": len(loaded_recents)}) + loaded_recents.extend(recents) recents = loaded_recents @@ -1116,6 +1128,8 @@ class SyncHandler: logger.debug("Fetching group data") await self._generate_sync_entry_for_groups(sync_result_builder) + num_events = 0 + # debug for https://github.com/matrix-org/synapse/issues/4422 for joined_room in sync_result_builder.joined: room_id = joined_room.room_id @@ -1123,6 +1137,14 @@ class SyncHandler: issue4422_logger.debug( "Sync result for newly joined room %s: %r", room_id, joined_room ) + num_events += len(joined_room.timeline.events) + + log_kv( + { + "joined_rooms_in_result": len(sync_result_builder.joined), + "events_in_result": num_events, + } + ) logger.debug("Sync response calculation complete") return SyncResult( @@ -1467,6 +1489,7 @@ class SyncHandler: if not sync_result_builder.full_state: if since_token and not ephemeral_by_room and not account_data_by_room: have_changed = await self._have_rooms_changed(sync_result_builder) + log_kv({"rooms_have_changed": have_changed}) if not have_changed: tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key @@ -1501,25 +1524,30 @@ class SyncHandler: tags_by_room = await self.store.get_tags_for_user(user_id) + log_kv({"rooms_changed": len(room_changes.room_entries)}) + room_entries = room_changes.room_entries invited = room_changes.invited knocked = room_changes.knocked newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms - async def handle_room_entries(room_entry): - logger.debug("Generating room entry for %s", room_entry.room_id) - res = await self._generate_room_entry( - sync_result_builder, - ignored_users, - room_entry, - ephemeral=ephemeral_by_room.get(room_entry.room_id, []), - tags=tags_by_room.get(room_entry.room_id), - account_data=account_data_by_room.get(room_entry.room_id, {}), - always_include=sync_result_builder.full_state, - ) - logger.debug("Generated room entry for %s", room_entry.room_id) - return res + async def handle_room_entries(room_entry: "RoomSyncResultBuilder"): + with start_active_span("generate_room_entry"): + set_tag("room_id", room_entry.room_id) + log_kv({"events": len(room_entry.events or [])}) + logger.debug("Generating room entry for %s", room_entry.room_id) + res = await self._generate_room_entry( + sync_result_builder, + ignored_users, + room_entry, + ephemeral=ephemeral_by_room.get(room_entry.room_id, []), + tags=tags_by_room.get(room_entry.room_id), + account_data=account_data_by_room.get(room_entry.room_id, {}), + always_include=sync_result_builder.full_state, + ) + logger.debug("Generated room entry for %s", room_entry.room_id) + return res await concurrently_execute(handle_room_entries, room_entries, 10) @@ -1932,6 +1960,12 @@ class SyncHandler: room_id = room_builder.room_id since_token = room_builder.since_token upto_token = room_builder.upto_token + log_kv( + { + "since_token": since_token, + "upto_token": upto_token, + } + ) batch = await self._load_filtered_recents( room_id, @@ -1941,6 +1975,13 @@ class SyncHandler: potential_recents=events, newly_joined_room=newly_joined, ) + log_kv( + { + "batch_events": len(batch.events), + "prev_batch": batch.prev_batch, + "batch_limited": batch.limited, + } + ) # Note: `batch` can be both empty and limited here in the case where # `_load_filtered_recents` can't find any events the user should see diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index a12fa30bfd..91ba93372c 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -572,6 +572,25 @@ def parse_string_from_args( return strings[0] +@overload +def parse_json_value_from_request(request: Request) -> JsonDict: + ... + + +@overload +def parse_json_value_from_request( + request: Request, allow_empty_body: Literal[False] +) -> JsonDict: + ... + + +@overload +def parse_json_value_from_request( + request: Request, allow_empty_body: bool = False +) -> Optional[JsonDict]: + ... + + def parse_json_value_from_request( request: Request, allow_empty_body: bool = False ) -> Optional[JsonDict]: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 941fb238b7..b0834720ad 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -258,7 +258,7 @@ class Mailer: # actually sort our so-called rooms_in_order list, most recent room first rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) - rooms = [] + rooms: List[Dict[str, Any]] = [] for r in rooms_in_order: roomvars = await self._get_room_vars( @@ -362,6 +362,7 @@ class Mailer: "notifs": [], "invite": is_invite, "link": self._make_room_link(room_id), + "avatar_url": await self._get_room_avatar(room_state_ids), } if not is_invite: @@ -393,6 +394,27 @@ class Mailer: return room_vars + async def _get_room_avatar( + self, + room_state_ids: StateMap[str], + ) -> Optional[str]: + """ + Retrieve the avatar url for this room---if it exists. + + Args: + room_state_ids: The event IDs of the current room state. + + Returns: + room's avatar url if it's present and a string; otherwise None. + """ + event_id = room_state_ids.get((EventTypes.RoomAvatar, "")) + if event_id: + ev = await self.store.get_event(event_id) + url = ev.content.get("url") + if isinstance(url, str): + return url + return None + async def _get_notif_vars( self, notif: Dict[str, Any], diff --git a/synapse/res/providers.json b/synapse/res/providers.json new file mode 100644 index 0000000000..f1838f9559 --- /dev/null +++ b/synapse/res/providers.json @@ -0,0 +1,17 @@ +[ + { + "provider_name": "Twitter", + "provider_url": "http://www.twitter.com/", + "endpoints": [ + { + "schemes": [ + "https://twitter.com/*/status/*", + "https://*.twitter.com/*/status/*", + "https://twitter.com/*/moments/*", + "https://*.twitter.com/*/moments/*" + ], + "url": "https://publish.twitter.com/oembed" + } + ] + } +] \ No newline at end of file diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py index 42201afc86..f5a38c2670 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes from synapse.api.errors import NotFoundError, SynapseError @@ -101,7 +101,9 @@ class SendServerNoticeServlet(RestServlet): return 200, {"event_id": event.event_id} - def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]: + def on_PUT( + self, request: SynapseRequest, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/client/_base.py b/synapse/rest/client/_base.py index 0443f4571c..a0971ce994 100644 --- a/synapse/rest/client/_base.py +++ b/synapse/rest/client/_base.py @@ -16,7 +16,7 @@ """ import logging import re -from typing import Iterable, Pattern +from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.urls import CLIENT_API_PREFIX @@ -76,7 +76,10 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) ) -def interactive_auth_handler(orig): +C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]]) + + +def interactive_auth_handler(orig: C) -> C: """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors Takes a on_POST method which returns an Awaitable (errcode, body) response @@ -91,10 +94,10 @@ def interactive_auth_handler(orig): await self.auth_handler.check_auth """ - async def wrapped(*args, **kwargs): + async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]: try: return await orig(*args, **kwargs) except InteractiveAuthIncompleteError as e: return 401, e.result - return wrapped + return cast(C, wrapped) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fb5ad2906e..aefaaa8ae8 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -16,9 +16,11 @@ import logging import random from http import HTTPStatus -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple from urllib.parse import urlparse +from twisted.web.server import Request + from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, @@ -28,15 +30,17 @@ from synapse.api.errors import ( ) from synapse.config.emailconfig import ThreepidBehaviour from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html +from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed, validate_email @@ -68,7 +72,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): template_text=self.config.email_password_reset_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -159,7 +163,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -169,7 +173,7 @@ class PasswordRestServlet(RestServlet): self._set_password_handler = hs.get_set_password_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -190,6 +194,7 @@ class PasswordRestServlet(RestServlet): # # In the second case, we require a password to confirm their identity. + requester = None if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) try: @@ -206,16 +211,15 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise user_id = requester.user.to_string() else: - requester = None try: result, params, session_id = await self.auth_handler.check_ui_auth( [[LoginType.EMAIL_IDENTITY]], @@ -230,11 +234,11 @@ class PasswordRestServlet(RestServlet): # If a password is available now, hash the provided password and # store it for later. if new_password: - password_hash = await self.auth_handler.hash(new_password) + new_password_hash = await self.auth_handler.hash(new_password) await self.auth_handler.set_session_data( e.session_id, UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + new_password_hash, ) raise @@ -264,7 +268,7 @@ class PasswordRestServlet(RestServlet): # If we have a password in this request, prefer it. Otherwise, use the # password hash from an earlier request. if new_password: - password_hash = await self.auth_handler.hash(new_password) + password_hash: Optional[str] = await self.auth_handler.hash(new_password) elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None @@ -288,7 +292,7 @@ class PasswordRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -296,7 +300,7 @@ class DeactivateAccountRestServlet(RestServlet): self._deactivate_account_handler = hs.get_deactivate_account_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) erase = body.get("erase", False) if not isinstance(erase, bool): @@ -338,7 +342,7 @@ class DeactivateAccountRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.config = hs.config @@ -353,7 +357,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): template_text=self.config.email_add_threepid_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -449,7 +453,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict( body, ["client_secret", "country", "phone_number", "send_attempt"] @@ -525,11 +529,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): "/add_threepid/email/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() @@ -539,7 +539,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config.email_add_threepid_template_failure_html ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> None: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -596,18 +596,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "/add_threepid/msisdn/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: if not self.config.account_threepid_delegate_msisdn: raise SynapseError( 400, @@ -632,7 +628,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -640,14 +636,14 @@ class ThreepidRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.datastore = self.hs.get_datastore() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) threepids = await self.datastore.user_get_threepids(requester.user.to_string()) return 200, {"threepids": threepids} - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -688,7 +684,7 @@ class ThreepidRestServlet(RestServlet): class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -696,7 +692,7 @@ class ThreepidAddRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -738,13 +734,13 @@ class ThreepidAddRestServlet(RestServlet): class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) @@ -767,14 +763,14 @@ class ThreepidBindRestServlet(RestServlet): class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.datastore = self.hs.get_datastore() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """Unbind the given 3pid from a specific identity server, or identity servers that are known to have this 3pid bound """ @@ -798,13 +794,13 @@ class ThreepidUnbindRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN @@ -835,7 +831,7 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} -def assert_valid_next_link(hs: "HomeServer", next_link: str): +def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: """ Raises a SynapseError if a given next_link value is invalid @@ -877,11 +873,11 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str): class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) response = {"user_id": requester.user.to_string()} @@ -894,7 +890,7 @@ class WhoamiRestServlet(RestServlet): return 200, response -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py index 7517e9304e..d1badbdf3b 100644 --- a/synapse/rest/client/account_data.py +++ b/synapse/rest/client/account_data.py @@ -13,12 +13,19 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import AuthError, NotFoundError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet): "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.handler = hs.get_account_data_handler() - async def on_PUT(self, request, user_id, account_data_type): + async def on_PUT( + self, request: SynapseRequest, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet): return 200, {} - async def on_GET(self, request, user_id, account_data_type): + async def on_GET( + self, request: SynapseRequest, user_id: str, account_data_type: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet): "/account_data/(?P<account_data_type>[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.handler = hs.get_account_data_handler() - async def on_PUT(self, request, user_id, room_id, account_data_type): + async def on_PUT( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet): return 200, {} - async def on_GET(self, request, user_id, room_id, account_data_type): + async def on_GET( + self, + request: SynapseRequest, + user_id: str, + room_id: str, + account_data_type: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot get account data for other users.") @@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet): return 200, event -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: AccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py index c3667ff8aa..a7e9aa3e9b 100644 --- a/synapse/rest/client/groups.py +++ b/synapse/rest/client/groups.py @@ -15,7 +15,7 @@ import logging from functools import wraps -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple from twisted.web.server import Request @@ -43,14 +43,18 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _validate_group_id(f): +def _validate_group_id( + f: Callable[..., Awaitable[Tuple[int, JsonDict]]] +) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]: """Wrapper to validate the form of the group ID. Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. """ @wraps(f) - def wrapper(self, request: Request, group_id: str, *args, **kwargs): + def wrapper( + self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any + ) -> Awaitable[Tuple[int, JsonDict]]: if not GroupID.is_valid(group_id): raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) @@ -156,7 +160,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): group_id: str, category_id: Optional[str], room_id: str, - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -188,7 +192,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): @_validate_group_id async def on_DELETE( self, request: SynapseRequest, group_id: str, category_id: str, room_id: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -451,7 +455,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): @_validate_group_id async def on_DELETE( self, request: SynapseRequest, group_id: str, role_id: str, user_id: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -674,7 +678,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): @_validate_group_id async def on_PUT( self, request: SynapseRequest, group_id: str, room_id: str, config_key: str - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -706,7 +710,7 @@ class GroupAdminUsersInviteServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: SynapseRequest, group_id, user_id + self, request: SynapseRequest, group_id: str, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -738,7 +742,7 @@ class GroupAdminUsersKickServlet(RestServlet): @_validate_group_id async def on_PUT( - self, request: SynapseRequest, group_id, user_id + self, request: SynapseRequest, group_id: str, user_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() diff --git a/synapse/rest/client/knock.py b/synapse/rest/client/knock.py index 68fb08d0ba..0152a0c66a 100644 --- a/synapse/rest/client/knock.py +++ b/synapse/rest/client/knock.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from twisted.web.server import Request @@ -96,7 +96,9 @@ class KnockRoomAliasServlet(RestServlet): return 200, {"room_id": room_id} - def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + def on_PUT( + self, request: Request, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index 702b351d18..fb3211bf3a 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -12,22 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Union + +import attr + from synapse.api.errors import ( NotFoundError, StoreError, SynapseError, UnrecognizedRequestError, ) +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_value_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class RuleSpec: + scope: str + template: str + rule_id: str + attr: Optional[str] class PushRuleRestServlet(RestServlet): @@ -36,7 +54,7 @@ class PushRuleRestServlet(RestServlet): "Unrecognised request: You probably wanted a trailing slash" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -45,7 +63,7 @@ class PushRuleRestServlet(RestServlet): self._users_new_default_push_rules = hs.config.users_new_default_push_rules - async def on_PUT(self, request, path): + async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -57,25 +75,25 @@ class PushRuleRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) - if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: + if "/" in spec.rule_id or "\\" in spec.rule_id: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if "attr" in spec: + if spec.attr: await self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) return 200, {} - if spec["rule_id"].startswith("."): + if spec.rule_id.startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec["template"], spec["rule_id"], content + spec.template, spec.rule_id, content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -106,7 +124,9 @@ class PushRuleRestServlet(RestServlet): return 200, {} - async def on_DELETE(self, request, path): + async def on_DELETE( + self, request: SynapseRequest, path: str + ) -> Tuple[int, JsonDict]: if self._is_worker: raise Exception("Cannot handle DELETE /push_rules on worker") @@ -127,7 +147,7 @@ class PushRuleRestServlet(RestServlet): else: raise - async def on_GET(self, request, path): + async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -138,40 +158,42 @@ class PushRuleRestServlet(RestServlet): rules = format_push_rules_for_user(requester.user, rules) - path = path.split("/")[1:] + path_parts = path.split("/")[1:] - if path == []: + if path_parts == []: # we're a reference impl: pedantry is our job. raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == "": + if path_parts[0] == "": return 200, rules - elif path[0] == "global": - result = _filter_ruleset_with_path(rules["global"], path[1:]) + elif path_parts[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path_parts[1:]) return 200, result else: raise UnrecognizedRequestError() - def notify_user(self, user_id): + def notify_user(self, user_id: str) -> None: stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - async def set_rule_attr(self, user_id, spec, val): - if spec["attr"] not in ("enabled", "actions"): + async def set_rule_attr( + self, user_id: str, spec: RuleSpec, val: Union[bool, JsonDict] + ) -> None: + if spec.attr not in ("enabled", "actions"): # for the sake of potential future expansion, shouldn't report # 404 in the case of an unknown request so check it corresponds to # a known attribute first. raise UnrecognizedRequestError() namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] + rule_id = spec.rule_id is_default_rule = rule_id.startswith(".") if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) - if spec["attr"] == "enabled": + if spec.attr == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] if not isinstance(val, bool): @@ -179,14 +201,18 @@ class PushRuleRestServlet(RestServlet): # This should *actually* take a dict, but many clients pass # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") - return await self.store.set_push_rule_enabled( + await self.store.set_push_rule_enabled( user_id, namespaced_rule_id, val, is_default_rule ) - elif spec["attr"] == "actions": + elif spec.attr == "actions": + if not isinstance(val, dict): + raise SynapseError(400, "Value must be a dict") actions = val.get("actions") + if not isinstance(actions, list): + raise SynapseError(400, "Value for 'actions' must be dict") _check_actions(actions) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec["rule_id"] + rule_id = spec.rule_id is_default_rule = rule_id.startswith(".") if is_default_rule: if user_id in self._users_new_default_push_rules: @@ -196,22 +222,21 @@ class PushRuleRestServlet(RestServlet): if namespaced_rule_id not in rule_ids: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return await self.store.set_push_rule_actions( + await self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: raise UnrecognizedRequestError() -def _rule_spec_from_path(path): +def _rule_spec_from_path(path: Sequence[str]) -> RuleSpec: """Turn a sequence of path components into a rule spec Args: - path (sequence[unicode]): the URL path components. + path: the URL path components. Returns: - dict: rule spec dict, containing scope/template/rule_id entries, - and possibly attr. + rule spec, containing scope/template/rule_id entries, and possibly attr. Raises: UnrecognizedRequestError if the path components cannot be parsed. @@ -237,17 +262,18 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = {"scope": scope, "template": template, "rule_id": rule_id} - path = path[1:] + attr = None if len(path) > 0 and len(path[0]) > 0: - spec["attr"] = path[0] + attr = path[0] - return spec + return RuleSpec(scope, template, rule_id, attr) -def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): +def _rule_tuple_from_request_object( + rule_template: str, rule_id: str, req_obj: JsonDict +) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]: if rule_template in ["override", "underride"]: if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") @@ -277,7 +303,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): return conditions, actions -def _check_actions(actions): +def _check_actions(actions: List[Union[str, JsonDict]]) -> None: if not isinstance(actions, list): raise InvalidRuleException("No actions found") @@ -290,7 +316,7 @@ def _check_actions(actions): raise InvalidRuleException("Unrecognised action") -def _filter_ruleset_with_path(ruleset, path): +def _filter_ruleset_with_path(ruleset: JsonDict, path: List[str]) -> JsonDict: if path == []: raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR @@ -315,7 +341,7 @@ def _filter_ruleset_with_path(ruleset, path): if r["rule_id"] == rule_id: the_rule = r if the_rule is None: - raise NotFoundError + raise NotFoundError() path = path[1:] if len(path) == 0: @@ -330,25 +356,25 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError() -def _priority_class_from_spec(spec): - if spec["template"] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec["template"])) - pc = PRIORITY_CLASS_MAP[spec["template"]] +def _priority_class_from_spec(spec: RuleSpec) -> int: + if spec.template not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec.template)) + pc = PRIORITY_CLASS_MAP[spec.template] return pc -def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec["rule_id"]) +def _namespaced_rule_id_from_spec(spec: RuleSpec) -> str: + return _namespaced_rule_id(spec, spec.rule_id) -def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec["template"], rule_id) +def _namespaced_rule_id(spec: RuleSpec, rule_id: str) -> str: + return "global/%s/%s" % (spec.template, rule_id) class InvalidRuleException(Exception): pass -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index d9ab836cd8..9770413c61 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -13,13 +13,20 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet): "/(?P<event_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.presence_handler = hs.get_presence_handler() - async def on_POST(self, request, room_id, receipt_type, event_id): + async def on_POST( + self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if receipt_type != "m.read": @@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 7b5f49d635..8f3dd2a101 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -14,7 +14,9 @@ # limitations under the License. import logging import random -from typing import List, Union +from typing import TYPE_CHECKING, List, Optional, Tuple + +from twisted.web.server import Request import synapse import synapse.api.auth @@ -29,15 +31,13 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.config import ConfigError -from synapse.config.captcha import CaptchaConfig -from synapse.config.consent import ConsentConfig from synapse.config.emailconfig import ThreepidBehaviour +from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig -from synapse.config.registration import RegistrationConfig from synapse.config.server import is_threepid_reserved from synapse.handlers.auth import AuthHandler from synapse.handlers.ui_auth import UIAuthSessionDataConstants -from synapse.http.server import finish_request, respond_with_html +from synapse.http.server import HttpServer, finish_request, respond_with_html from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -45,6 +45,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer from synapse.types import JsonDict @@ -59,17 +60,16 @@ from synapse.util.threepids import ( from ._base import client_patterns, interactive_auth_handler +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class EmailRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/register/email/requestToken$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() @@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): template_text=self.config.email_registration_template_text, ) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.local_threepid_handling_disabled_due_to_email_config: logger.warning( @@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/register/msisdn/requestToken$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.identity_handler = hs.get_identity_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) assert_params_in_dict( @@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet): "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet): self.config.email_registration_template_failure_html ) - async def on_GET(self, request, medium): + async def on_GET(self, request: Request, medium: str) -> None: if medium != "email": raise SynapseError( 400, "This medium is currently not supported for registration" @@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet): PATTERNS = client_patterns("/register/available") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() @@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ), ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.hs.config.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN @@ -387,11 +375,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): unstable=True, ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -402,7 +386,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, ) - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) if not self.hs.config.enable_registration: @@ -419,11 +403,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): class RegisterRestServlet(RestServlet): PATTERNS = client_patterns("/register$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs @@ -445,23 +425,21 @@ class RegisterRestServlet(RestServlet): ) @interactive_auth_handler - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: body = parse_json_object_from_request(request) client_addr = request.getClientIP() await self.ratelimiter.ratelimit(None, client_addr, update=False) - kind = b"user" - if b"kind" in request.args: - kind = request.args[b"kind"][0] + kind = parse_string(request, "kind", default="user") - if kind == b"guest": + if kind == "guest": ret = await self._do_guest_registration(body, address=client_addr) return ret - elif kind != b"user": + elif kind != "user": raise UnrecognizedRequestError( - "Do not understand membership kind: %s" % (kind.decode("utf8"),) + f"Do not understand membership kind: {kind}", ) if self._msc2918_enabled: @@ -748,8 +726,12 @@ class RegisterRestServlet(RestServlet): return 200, return_dict async def _do_appservice_registration( - self, username, as_token, body, should_issue_refresh_token: bool = False - ): + self, + username: str, + as_token: str, + body: JsonDict, + should_issue_refresh_token: bool = False, + ) -> JsonDict: user_id = await self.registration_handler.appservice_register( username, as_token ) @@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet): params: JsonDict, is_appservice_ghost: bool = False, should_issue_refresh_token: bool = False, - ): + ) -> JsonDict: """Complete registration of newly-registered user Allocates device_id if one was not given; also creates access_token. @@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet): return result - async def _do_guest_registration(self, params, address=None): + async def _do_guest_registration( + self, params: JsonDict, address: Optional[str] = None + ) -> Tuple[int, JsonDict]: if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( @@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet): def _calculate_registration_flows( - # technically `config` has to provide *all* of these interfaces, not just one - config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], - auth_handler: AuthHandler, + config: HomeServerConfig, auth_handler: AuthHandler ) -> List[List[str]]: """Get a suitable flows list for registration @@ -929,7 +911,7 @@ def _calculate_registration_flows( return flows -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: EmailRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 0821cd285f..0b0711c03c 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -19,25 +19,32 @@ any time to reflect changes in the MSC. """ import logging +from typing import TYPE_CHECKING, Awaitable, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import ShadowBanError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache from synapse.storage.relations import ( AggregationPaginationToken, PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.stringutils import random_string from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet): "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: http_server.register_paths( "POST", client_patterns(self.PATTERN + "$", releases=()), @@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet): self.__class__.__name__, ) - def on_PUT(self, request, *args, **kwargs): + def on_PUT( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( - request, self.on_PUT_or_POST, request, *args, **kwargs + request, + self.on_PUT_or_POST, + request, + room_id, + parent_id, + relation_type, + event_type, + txn_id, ) async def on_PUT_or_POST( - self, request, room_id, parent_id, relation_type, event_type, txn_id=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) if event_type == EventTypes.Member: @@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet): self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet): # This gets the original event and checks that a) the event exists and # b) the user is allowed to view it. event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") limit = parse_integer(request, "limit", default=5) from_token_str = parse_string(request, "from") @@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() async def on_GET( - self, request, room_id, parent_id, relation_type=None, event_type=None - ): + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet): # This checks that a) the event exists and b) the user is allowed to # view it. event = await self.event_handler.get_event(requester.user, room_id, parent_id) + if event is None: + raise SynapseError(404, "Unknown parent event.") if relation_type not in (RelationTypes.ANNOTATION, None): raise SynapseError(400, "Relation type must be 'annotation'") @@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): releases=(), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.event_handler = hs.get_event_handler() - async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + async def on_GET( + self, + request: SynapseRequest, + room_id: str, + parent_id: str, + relation_type: str, + event_type: str, + key: str, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( @@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): return 200, return_value -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationSendServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/report_event.py b/synapse/rest/client/report_event.py index 07ea39a8a3..d4a4adb50c 100644 --- a/synapse/rest/client/report_event.py +++ b/synapse/rest/client/report_event.py @@ -14,26 +14,35 @@ import logging from http import HTTPStatus +from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() - async def on_POST(self, request, room_id, event_id): + async def on_POST( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) user_id = requester.user.to_string() @@ -64,5 +73,5 @@ class ReportEventRestServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReportEventRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index c5c54564be..9b0c546505 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -16,9 +16,11 @@ """ This module contains REST servlets to do with rooms: /rooms/<paths> """ import logging import re -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple from urllib import parse as urlparse +from twisted.web.server import Request + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -30,6 +32,7 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 +from synapse.http.server import HttpServer from synapse.http.servlet import ( ResolveRoomIdMixin, RestServlet, @@ -57,7 +60,7 @@ logger = logging.getLogger(__name__) class TransactionRestServlet(RestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.txns = HttpTransactionCache(hs) @@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet): class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - def on_PUT(self, request, txn_id): + def on_PUT( + self, request: SynapseRequest, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) info, _ = await self._room_creation_handler.create_room( @@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet): return 200, info - def get_room_config(self, request): + def get_room_config(self, request: Request) -> JsonDict: user_supplied_config = parse_json_object_from_request(request) return user_supplied_config # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /room/$roomid/state/$eventtype no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" @@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet): self.__class__.__name__, ) - def on_GET_no_state_key(self, request, room_id, event_type): + def on_GET_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_GET(request, room_id, event_type, "") - def on_PUT_no_state_key(self, request, room_id, event_type): + def on_PUT_no_state_key( + self, request: SynapseRequest, room_id: str, event_type: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.on_PUT(request, room_id, event_type, "") - async def on_GET(self, request, room_id, event_type, state_key): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, state_key: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] @@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + # Format must be event or content, per the parse_string call above. + raise RuntimeError(f"Unknown format: {format:r}.") + + async def on_PUT( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + state_key: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if txn_id: @@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - async def on_POST(self, request, room_id, event_type, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) - event_dict = { + event_dict: JsonDict = { "type": event_type, "content": content, "room_id": room_id, "sender": requester.user.to_string(), } + # Twisted will have processed the args by now. + assert request.args is not None if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) @@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_GET(self, request, room_id, event_type, txn_id): + def on_GET( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Tuple[int, str]: return 200, "Not implemented" - def on_PUT(self, request, room_id, event_type, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /join/$room_identifier[/$txn_id] PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) @@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): request: SynapseRequest, room_identifier: str, txn_id: Optional[str] = None, - ): + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: @@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): return 200, {"room_id": room_id} - def on_PUT(self, request, room_identifier, txn_id): + def on_PUT( + self, request: SynapseRequest, room_identifier: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: server = parse_string(request, "server") try: @@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): return 200, data - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server") @@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet): class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: # TODO support Pagination stream API (limit/tokens) requester = await self.auth.get_user_by_req(request, allow_guest=True) handler = self.message_handler @@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet): class JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) users_with_profile = await self.message_handler.get_joined_members( @@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet): class RoomMessageListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request( self.store, request, default_limit=10 ) + # Twisted will have processed the args by now. + assert request.args is not None as_client_event = b"raw" not in request.args filter_str = parse_string(request, "filter", encoding="utf-8") if filter_str: @@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet): class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, List[JsonDict]]: requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room events = await self.message_handler.get_state_events( @@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet): class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() self.store = hs.get_datastore() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( @@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) try: event = await self.event_handler.get_event( @@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = await self._event_serializer.serialize_event(event, time_now) - return 200, event + event_dict = await self._event_serializer.serialize_event(event, time_now) + return 200, event_dict - return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) class RoomEventContextServlet(RestServlet): @@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - async def on_GET(self, request, room_id, event_id): + async def on_GET( + self, request: SynapseRequest, room_id: str, event_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet): class RoomForgetRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, txn_id=None): + async def on_POST( + self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} - def on_PUT(self, request, room_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/[invite|join|leave] PATTERNS = ( "/rooms/(?P<room_id>[^/]*)/" @@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, membership_action, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + membership_action: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { @@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet): return 200, return_value - def _has_3pid_invite_keys(self, content): + def _has_3pid_invite_keys(self, content: JsonDict) -> bool: for key in {"id_server", "medium", "address"}: if key not in content: return False return True - def on_PUT(self, request, room_id, membership_action, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() - def register(self, http_server): + def register(self, http_server: HttpServer) -> None: PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST(self, request, room_id, event_id, txn_id=None): + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_id: str, + txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_PUT(self, request, room_id, event_id, txn_id): + def on_PUT( + self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( @@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet): hs.config.worker.writers.typing == hs.get_instance_name() ) - async def on_PUT(self, request, room_id, user_id): + async def on_PUT( + self, request: SynapseRequest, room_id: str, user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) if not self._is_typing_writer: @@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet): self.auth = hs.get_auth() self.directory_handler = hs.get_directory_handler() - async def on_GET(self, request, room_id): + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) alias_list = await self.directory_handler.get_aliases_for_room( @@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet): class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.search_handler = hs.get_search_handler() self.auth = hs.get_auth() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) @@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet): class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} -def register_txn_path(servlet, regex_string, http_server, with_get=False): +def register_txn_path( + servlet: RestServlet, + regex_string: str, + http_server: HttpServer, + with_get: bool = False, +) -> None: """Registers a transaction-based path. This registers two paths: @@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): POST regex_string Args: - regex_string (str): The regex string to register. Must NOT have a - trailing $ as this string will be appended to. - http_server : The http_server to register paths with. + regex_string: The regex string to register. Must NOT have a + trailing $ as this string will be appended to. + http_server: The http_server to register paths with. with_get: True to also register respective GET paths for the PUTs. """ + on_POST = getattr(servlet, "on_POST", None) + on_PUT = getattr(servlet, "on_PUT", None) + if on_POST is None or on_PUT is None: + raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path") http_server.register_paths( "POST", client_patterns(regex_string + "$", v1=True), - servlet.on_POST, + on_POST, servlet.__class__.__name__, ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - servlet.on_PUT, + on_PUT, servlet.__class__.__name__, ) + on_GET = getattr(servlet, "on_GET", None) if with_get: + if on_GET is None: + raise RuntimeError( + "register_txn_path called with with_get = True, but no on_GET method exists" + ) http_server.register_paths( "GET", client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - servlet.on_GET, + on_GET, servlet.__class__.__name__, ) @@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet): ) -def register_servlets(hs: "HomeServer", http_server, is_worker=False): +def register_servlets( + hs: "HomeServer", http_server: HttpServer, is_worker: bool = False +) -> None: RoomStateEventRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server) @@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False): RoomForgetRestServlet(hs).register(http_server) -def register_deprecated_servlets(hs, http_server): +def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomInitialSyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 3172aba605..ed96978448 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -14,10 +14,14 @@ import logging import re +from typing import TYPE_CHECKING, Awaitable, List, Tuple + +from twisted.web.server import Request from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import AuthError, Codes, SynapseError from synapse.appservice import ApplicationService +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -25,10 +29,14 @@ from synapse.http.servlet import ( parse_string, parse_strings_from_args, ) +from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -66,7 +74,7 @@ class RoomBatchSendEventRestServlet(RestServlet): ), ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.store = hs.get_datastore() @@ -76,7 +84,7 @@ class RoomBatchSendEventRestServlet(RestServlet): self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) - async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int: + async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: ( most_recent_prev_event_id, most_recent_prev_event_depth, @@ -118,7 +126,7 @@ class RoomBatchSendEventRestServlet(RestServlet): def _create_insertion_event_dict( self, sender: str, room_id: str, origin_server_ts: int - ): + ) -> JsonDict: """Creates an event dict for an "insertion" event with the proper fields and a random chunk ID. @@ -128,7 +136,7 @@ class RoomBatchSendEventRestServlet(RestServlet): origin_server_ts: Timestamp when the event was sent Returns: - Tuple of event ID and stream ordering position + The new event dictionary to insert. """ next_chunk_id = random_string(8) @@ -164,7 +172,9 @@ class RoomBatchSendEventRestServlet(RestServlet): return create_requester(user_id, app_service=app_service) - async def on_POST(self, request, room_id): + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=False) if not requester.app_service: @@ -176,6 +186,7 @@ class RoomBatchSendEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ["state_events_at_start", "events"]) + assert request.args is not None prev_events_from_query = parse_strings_from_args(request.args, "prev_event") chunk_id_from_query = parse_string(request, "chunk_id") @@ -425,16 +436,18 @@ class RoomBatchSendEventRestServlet(RestServlet): ], } - def on_GET(self, request, room_id): + def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: return 501, "Not implemented" - def on_PUT(self, request, room_id): + def on_PUT( + self, request: SynapseRequest, room_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, room_id ) -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: msc2716_enabled = hs.config.experimental.msc2716_enabled if msc2716_enabled: diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index 263596be86..37e39570f6 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -13,16 +13,23 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional, Tuple from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, parse_string, ) +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -31,16 +38,14 @@ class RoomKeysServlet(RestServlet): "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_PUT(self, request, room_id, session_id): + async def on_PUT( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Uploads one or more encrypted E2E room keys for backup purposes. room_id: the ID of the room the keys are for (optional) @@ -133,7 +138,9 @@ class RoomKeysServlet(RestServlet): ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) return 200, ret - async def on_GET(self, request, room_id, session_id): + async def on_GET( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Retrieves one or more encrypted E2E room keys for backup purposes. Symmetric with the PUT version of the API. @@ -215,7 +222,9 @@ class RoomKeysServlet(RestServlet): return 200, room_keys - async def on_DELETE(self, request, room_id, session_id): + async def on_DELETE( + self, request: SynapseRequest, room_id: Optional[str], session_id: Optional[str] + ) -> Tuple[int, JsonDict]: """ Deletes one or more encrypted E2E room keys for a user for backup purposes. @@ -242,16 +251,12 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): PATTERNS = client_patterns("/room_keys/version$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: """ Create a new backup version for this user's room_keys with the given info. The version is allocated by the server and returned to the user @@ -295,16 +300,14 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() - async def on_GET(self, request, version): + async def on_GET( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Retrieve the version information about a given version of the user's room_keys backup. If the version part is missing, returns info about the @@ -332,7 +335,9 @@ class RoomKeysVersionServlet(RestServlet): raise SynapseError(404, "No backup found", Codes.NOT_FOUND) return 200, info - async def on_DELETE(self, request, version): + async def on_DELETE( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Delete the information about a given version of the user's room_keys backup. If the version part is missing, deletes the most @@ -351,7 +356,9 @@ class RoomKeysVersionServlet(RestServlet): await self.e2e_room_keys_handler.delete_version(user_id, version) return 200, {} - async def on_PUT(self, request, version): + async def on_PUT( + self, request: SynapseRequest, version: Optional[str] + ) -> Tuple[int, JsonDict]: """ Update the information about a given version of the user's room_keys backup. @@ -385,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet): return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomKeysServlet(hs).register(http_server) RoomKeysVersionServlet(hs).register(http_server) RoomKeysNewVersionServlet(hs).register(http_server) diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index d537d811d8..3322c8ef48 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -13,15 +13,21 @@ # limitations under the License. import logging -from typing import Tuple +from typing import TYPE_CHECKING, Awaitable, Tuple from synapse.http import servlet +from synapse.http.server import HttpServer from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.logging.opentracing import set_tag, trace from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,11 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$" ) - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -42,14 +44,18 @@ class SendToDeviceRestServlet(servlet.RestServlet): self.device_message_handler = hs.get_device_message_handler() @trace(opname="sendToDevice") - def on_PUT(self, request, message_type, txn_id): + def on_PUT( + self, request: SynapseRequest, message_type: str, txn_id: str + ) -> Awaitable[Tuple[int, JsonDict]]: set_tag("message_type", message_type) set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request( request, self._put, request, message_type, txn_id ) - async def _put(self, request, message_type, txn_id): + async def _put( + self, request: SynapseRequest, message_type: str, txn_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -59,9 +65,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): requester, message_type, content["messages"] ) - response: Tuple[int, dict] = (200, {}) - return response + return 200, {} -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SendToDeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 65c37be3e9..1259058b9b 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -14,12 +14,24 @@ import itertools import logging from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.presence import UserPresenceState +from synapse.events import EventBase from synapse.events.utils import ( format_event_for_client_v2_without_room_id, format_event_raw, @@ -504,7 +516,7 @@ class SyncRestServlet(RestServlet): The room, encoded in our response format """ - def serialize(events): + def serialize(events: Iterable[EventBase]) -> Awaitable[List[JsonDict]]: return self._event_serializer.serialize_events( events, time_now=time_now, diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 94ff3719ce..914fb3acf5 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -15,28 +15,37 @@ """This module contains logic for storing HTTP PUT transactions. This is used to ensure idempotency when performing PUTs using the REST API.""" import logging +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple + +from twisted.python.failure import Failure +from twisted.web.server import Request from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import JsonDict from synapse.util.async_helpers import ObservableDeferred +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = self.hs.get_auth() self.clock = self.hs.get_clock() - self.transactions = { - # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) - } + # $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp) + self.transactions: Dict[ + str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int] + ] = {} # Try to clean entries every 30 mins. This means entries will exist # for at *LEAST* 30 mins, and at *MOST* 60 mins. self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) - def _get_transaction_key(self, request): + def _get_transaction_key(self, request: Request) -> str: """A helper function which returns a transaction key that can be used with TransactionCache for idempotent requests. @@ -45,15 +54,21 @@ class HttpTransactionCache: path and the access_token for the requesting user. Args: - request (twisted.web.http.Request): The incoming request. Must - contain an access_token. + request: The incoming request. Must contain an access_token. Returns: - str: A transaction key + A transaction key """ + assert request.path is not None token = self.auth.get_access_token_from_request(request) return request.path.decode("utf8") + "/" + token - def fetch_or_execute_request(self, request, fn, *args, **kwargs): + def fetch_or_execute_request( + self, + request: Request, + fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], + *args: Any, + **kwargs: Any, + ) -> Awaitable[Tuple[int, JsonDict]]: """A helper function for fetch_or_execute which extracts a transaction key from the given request. @@ -64,15 +79,20 @@ class HttpTransactionCache: self._get_transaction_key(request), fn, *args, **kwargs ) - def fetch_or_execute(self, txn_key, fn, *args, **kwargs): + def fetch_or_execute( + self, + txn_key: str, + fn: Callable[..., Awaitable[Tuple[int, JsonDict]]], + *args: Any, + **kwargs: Any, + ) -> Awaitable[Tuple[int, JsonDict]]: """Fetches the response for this transaction, or executes the given function to produce a response for this transaction. Args: - txn_key (str): A key to ensure idempotency should fetch_or_execute be - called again at a later point in time. - fn (function): A function which returns a tuple of - (response_code, response_dict). + txn_key: A key to ensure idempotency should fetch_or_execute be + called again at a later point in time. + fn: A function which returns a tuple of (response_code, response_dict). *args: Arguments to pass to fn. **kwargs: Keyword arguments to pass to fn. Returns: @@ -90,7 +110,7 @@ class HttpTransactionCache: # if the request fails with an exception, remove it # from the transaction map. This is done to ensure that we don't # cache transient errors like rate-limiting errors, etc. - def remove_from_map(err): + def remove_from_map(err: Failure) -> None: self.transactions.pop(txn_key, None) # we deliberately do not propagate the error any further, as we # expect the observers to have reported it. @@ -99,7 +119,7 @@ class HttpTransactionCache: return make_deferred_yieldable(observable.observe()) - def _cleanup(self): + def _cleanup(self) -> None: now = self.clock.time_msec() for key in list(self.transactions): ts = self.transactions[key][1] diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py new file mode 100644 index 0000000000..afe41823e4 --- /dev/null +++ b/synapse/rest/media/v1/oembed.py @@ -0,0 +1,135 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Optional + +import attr + +from synapse.http.client import SimpleHttpClient + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, auto_attribs=True) +class OEmbedResult: + # Either HTML content or URL must be provided. + html: Optional[str] + url: Optional[str] + title: Optional[str] + # Number of seconds to cache the content. + cache_age: int + + +class OEmbedError(Exception): + """An error occurred processing the oEmbed object.""" + + +class OEmbedProvider: + """ + A helper for accessing oEmbed content. + + It can be used to check if a URL should be accessed via oEmbed and for + requesting/parsing oEmbed content. + """ + + def __init__(self, hs: "HomeServer", client: SimpleHttpClient): + self._oembed_patterns = {} + for oembed_endpoint in hs.config.oembed.oembed_patterns: + for pattern in oembed_endpoint.url_patterns: + self._oembed_patterns[pattern] = oembed_endpoint.api_endpoint + self._client = client + + def get_oembed_url(self, url: str) -> Optional[str]: + """ + Check whether the URL should be downloaded as oEmbed content instead. + + Args: + url: The URL to check. + + Returns: + A URL to use instead or None if the original URL should be used. + """ + for url_pattern, endpoint in self._oembed_patterns.items(): + if url_pattern.fullmatch(url): + return endpoint + + # No match. + return None + + async def get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: + """ + Request content from an oEmbed endpoint. + + Args: + endpoint: The oEmbed API endpoint. + url: The URL to pass to the API. + + Returns: + An object representing the metadata returned. + + Raises: + OEmbedError if fetching or parsing of the oEmbed information fails. + """ + try: + logger.debug("Trying to get oEmbed content for url '%s'", url) + result = await self._client.get_json( + endpoint, + # TODO Specify max height / width. + # Note that only the JSON format is supported. + args={"url": url}, + ) + + # Ensure there's a version of 1.0. + if result.get("version") != "1.0": + raise OEmbedError("Invalid version: %s" % (result.get("version"),)) + + oembed_type = result.get("type") + + # Ensure the cache age is None or an int. + cache_age = result.get("cache_age") + if cache_age: + cache_age = int(cache_age) + + oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) + + # HTML content. + if oembed_type == "rich": + oembed_result.html = result.get("html") + return oembed_result + + if oembed_type == "photo": + oembed_result.url = result.get("url") + return oembed_result + + # TODO Handle link and video types. + + if "thumbnail_url" in result: + oembed_result.url = result.get("thumbnail_url") + return oembed_result + + raise OEmbedError("Incompatible oEmbed information.") + + except OEmbedError as e: + # Trap OEmbedErrors first so we can directly re-raise them. + logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) + raise + + except Exception as e: + # Trap any exception and let the code follow as usual. + # FIXME: pass through 404s and other error messages nicely + logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) + raise OEmbedError() from e diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0f051d4041..f108da05db 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -22,7 +22,7 @@ import re import shutil import sys import traceback -from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union +from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Union from urllib import parse as urlparse import attr @@ -43,6 +43,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1.media_storage import MediaStorage +from synapse.rest.media.v1.oembed import OEmbedError, OEmbedProvider +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -71,65 +73,44 @@ OG_TAG_VALUE_MAXLEN = 1000 ONE_HOUR = 60 * 60 * 1000 -# A map of globs to API endpoints. -_oembed_globs = { - # Twitter. - "https://publish.twitter.com/oembed": [ - "https://twitter.com/*/status/*", - "https://*.twitter.com/*/status/*", - "https://twitter.com/*/moments/*", - "https://*.twitter.com/*/moments/*", - # Include the HTTP versions too. - "http://twitter.com/*/status/*", - "http://*.twitter.com/*/status/*", - "http://twitter.com/*/moments/*", - "http://*.twitter.com/*/moments/*", - ], -} -# Convert the globs to regular expressions. -_oembed_patterns = {} -for endpoint, globs in _oembed_globs.items(): - for glob in globs: - # Convert the glob into a sane regular expression to match against. The - # rules followed will be slightly different for the domain portion vs. - # the rest. - # - # 1. The scheme must be one of HTTP / HTTPS (and have no globs). - # 2. The domain can have globs, but we limit it to characters that can - # reasonably be a domain part. - # TODO: This does not attempt to handle Unicode domain names. - # 3. Other parts allow a glob to be any one, or more, characters. - results = urlparse.urlparse(glob) - - # Ensure the scheme does not have wildcards (and is a sane scheme). - if results.scheme not in {"http", "https"}: - raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,)) - - pattern = urlparse.urlunparse( - [ - results.scheme, - re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"), - ] - + [re.escape(part).replace("\\*", ".+") for part in results[2:]] - ) - _oembed_patterns[re.compile(pattern)] = endpoint +@attr.s(slots=True, frozen=True, auto_attribs=True) +class MediaInfo: + """ + Information parsed from downloading media being previewed. + """ -@attr.s(slots=True) -class OEmbedResult: - # Either HTML content or URL must be provided. - html = attr.ib(type=Optional[str]) - url = attr.ib(type=Optional[str]) - title = attr.ib(type=Optional[str]) - # Number of seconds to cache the content. - cache_age = attr.ib(type=int) + # The Content-Type header of the response. + media_type: str + # The length (in bytes) of the downloaded media. + media_length: int + # The media filename, according to the server. This is parsed from the + # returned headers, if possible. + download_name: Optional[str] + # The time of the preview. + created_ts_ms: int + # Information from the media storage provider about where the file is stored + # on disk. + filesystem_id: str + filename: str + # The URI being previewed. + uri: str + # The HTTP response code. + response_code: int + # The timestamp (in milliseconds) of when this preview expires. + expires: int + # The ETag header of the response. + etag: Optional[str] -class OEmbedError(Exception): - """An error occurred processing the oEmbed object.""" +class PreviewUrlResource(DirectServeJsonResource): + """ + Generating URL previews is a complicated task which many potential pitfalls. + See docs/development/url_previews.md for discussion of the design and + algorithm followed in this module. + """ -class PreviewUrlResource(DirectServeJsonResource): isLeaf = True def __init__( @@ -157,6 +138,8 @@ class PreviewUrlResource(DirectServeJsonResource): self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage + self._oembed = OEmbedProvider(hs, self.client) + # We run the background jobs if we're the instance specified (or no # instance is specified, where we assume there is only one instance # serving media). @@ -275,18 +258,17 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("got media_info of '%s'", media_info) - if _is_media(media_info["media_type"]): - file_id = media_info["filesystem_id"] + if _is_media(media_info.media_type): + file_id = media_info.filesystem_id dims = await self.media_repo._generate_thumbnails( - None, file_id, file_id, media_info["media_type"], url_cache=True + None, file_id, file_id, media_info.media_type, url_cache=True ) og = { - "og:description": media_info["download_name"], - "og:image": "mxc://%s/%s" - % (self.server_name, media_info["filesystem_id"]), - "og:image:type": media_info["media_type"], - "matrix:image:size": media_info["media_length"], + "og:description": media_info.download_name, + "og:image": f"mxc://{self.server_name}/{media_info.filesystem_id}", + "og:image:type": media_info.media_type, + "matrix:image:size": media_info.media_length, } if dims: @@ -296,14 +278,14 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Couldn't get dims for %s" % url) # define our OG response for this media - elif _is_html(media_info["media_type"]): + elif _is_html(media_info.media_type): # TODO: somehow stop a big HTML tree from exploding synapse's RAM - with open(media_info["filename"], "rb") as file: + with open(media_info.filename, "rb") as file: body = file.read() - encoding = get_html_media_encoding(body, media_info["media_type"]) - og = decode_and_calc_og(body, media_info["uri"], encoding) + encoding = get_html_media_encoding(body, media_info.media_type) + og = decode_and_calc_og(body, media_info.uri, encoding) # pre-cache the image for posterity # FIXME: it might be cleaner to use the same flow as the main /preview_url @@ -311,14 +293,14 @@ class PreviewUrlResource(DirectServeJsonResource): # just rely on the caching on the master request to speed things up. if "og:image" in og and og["og:image"]: image_info = await self._download_url( - _rebase_url(og["og:image"], media_info["uri"]), user + _rebase_url(og["og:image"], media_info.uri), user ) - if _is_media(image_info["media_type"]): + if _is_media(image_info.media_type): # TODO: make sure we don't choke on white-on-transparent images - file_id = image_info["filesystem_id"] + file_id = image_info.filesystem_id dims = await self.media_repo._generate_thumbnails( - None, file_id, file_id, image_info["media_type"], url_cache=True + None, file_id, file_id, image_info.media_type, url_cache=True ) if dims: og["og:image:width"] = dims["width"] @@ -326,12 +308,11 @@ class PreviewUrlResource(DirectServeJsonResource): else: logger.warning("Couldn't get dims for %s", og["og:image"]) - og["og:image"] = "mxc://%s/%s" % ( - self.server_name, - image_info["filesystem_id"], - ) - og["og:image:type"] = image_info["media_type"] - og["matrix:image:size"] = image_info["media_length"] + og[ + "og:image" + ] = f"mxc://{self.server_name}/{image_info.filesystem_id}" + og["og:image:type"] = image_info.media_type + og["matrix:image:size"] = image_info.media_length else: del og["og:image"] else: @@ -357,98 +338,17 @@ class PreviewUrlResource(DirectServeJsonResource): # store OG in history-aware DB cache await self.store.store_url_cache( url, - media_info["response_code"], - media_info["etag"], - media_info["expires"] + media_info["created_ts"], + media_info.response_code, + media_info.etag, + media_info.expires + media_info.created_ts_ms, jsonog, - media_info["filesystem_id"], - media_info["created_ts"], + media_info.filesystem_id, + media_info.created_ts_ms, ) return jsonog.encode("utf8") - def _get_oembed_url(self, url: str) -> Optional[str]: - """ - Check whether the URL should be downloaded as oEmbed content instead. - - Args: - url: The URL to check. - - Returns: - A URL to use instead or None if the original URL should be used. - """ - for url_pattern, endpoint in _oembed_patterns.items(): - if url_pattern.fullmatch(url): - return endpoint - - # No match. - return None - - async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult: - """ - Request content from an oEmbed endpoint. - - Args: - endpoint: The oEmbed API endpoint. - url: The URL to pass to the API. - - Returns: - An object representing the metadata returned. - - Raises: - OEmbedError if fetching or parsing of the oEmbed information fails. - """ - try: - logger.debug("Trying to get oEmbed content for url '%s'", url) - result = await self.client.get_json( - endpoint, - # TODO Specify max height / width. - # Note that only the JSON format is supported. - args={"url": url}, - ) - - # Ensure there's a version of 1.0. - if result.get("version") != "1.0": - raise OEmbedError("Invalid version: %s" % (result.get("version"),)) - - oembed_type = result.get("type") - - # Ensure the cache age is None or an int. - cache_age = result.get("cache_age") - if cache_age: - cache_age = int(cache_age) - - oembed_result = OEmbedResult(None, None, result.get("title"), cache_age) - - # HTML content. - if oembed_type == "rich": - oembed_result.html = result.get("html") - return oembed_result - - if oembed_type == "photo": - oembed_result.url = result.get("url") - return oembed_result - - # TODO Handle link and video types. - - if "thumbnail_url" in result: - oembed_result.url = result.get("thumbnail_url") - return oembed_result - - raise OEmbedError("Incompatible oEmbed information.") - - except OEmbedError as e: - # Trap OEmbedErrors first so we can directly re-raise them. - logger.warning("Error parsing oEmbed metadata from %s: %r", url, e) - raise - - except Exception as e: - # Trap any exception and let the code follow as usual. - # FIXME: pass through 404s and other error messages nicely - logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) - raise OEmbedError() from e - - async def _download_url(self, url: str, user: str) -> Dict[str, Any]: + async def _download_url(self, url: str, user: str) -> MediaInfo: # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -459,11 +359,11 @@ class PreviewUrlResource(DirectServeJsonResource): # If this URL can be accessed via oEmbed, use that instead. url_to_download: Optional[str] = url - oembed_url = self._get_oembed_url(url) + oembed_url = self._oembed.get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. try: - oembed_result = await self._get_oembed_content(oembed_url, url) + oembed_result = await self._oembed.get_oembed_content(oembed_url, url) if oembed_result.url: url_to_download = oembed_result.url elif oembed_result.html: @@ -560,18 +460,18 @@ class PreviewUrlResource(DirectServeJsonResource): # therefore not expire it. raise - return { - "media_type": media_type, - "media_length": length, - "download_name": download_name, - "created_ts": time_now_ms, - "filesystem_id": file_id, - "filename": fname, - "uri": uri, - "response_code": code, - "expires": expires, - "etag": etag, - } + return MediaInfo( + media_type=media_type, + media_length=length, + download_name=download_name, + created_ts_ms=time_now_ms, + filesystem_id=file_id, + filename=fname, + uri=uri, + response_code=code, + expires=expires, + etag=etag, + ) def _start_expire_url_cache_data(self): return run_as_background_process( @@ -717,7 +617,7 @@ def get_html_media_encoding(body: bytes, content_type: str) -> str: def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None -) -> Dict[str, Optional[str]]: +) -> JsonDict: """ Calculate metadata for an HTML document. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 95d2caff62..0084d9f96c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -280,18 +280,18 @@ class LoggingTransaction: else: self.executemany(sql, args) - def execute_values(self, sql: str, *args: Any) -> List[Tuple]: + def execute_values(self, sql: str, *args: Any, fetch: bool = True) -> List[Tuple]: """Corresponds to psycopg2.extras.execute_values. Only available when using postgres. - Always sets fetch=True when caling `execute_values`, so will return the - results. + The `fetch` parameter must be set to False if the query does not return + rows (e.g. INSERTs). """ assert isinstance(self.database_engine, PostgresEngine) from psycopg2.extras import execute_values # type: ignore return self._do_execute( - lambda *x: execute_values(self.txn, *x, fetch=True), sql, *args + lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args ) def execute(self, sql: str, *args: Any) -> None: @@ -920,13 +920,23 @@ class DatabasePool: if k != keys[0]: raise RuntimeError("All items must have the same keys") - sql = "INSERT INTO %s (%s) VALUES(%s)" % ( - table, - ", ".join(k for k in keys[0]), - ", ".join("?" for _ in keys[0]), - ) + if isinstance(txn.database_engine, PostgresEngine): + # We use `execute_values` as it can be a lot faster than `execute_batch`, + # but it's only available on postgres. + sql = "INSERT INTO %s (%s) VALUES ?" % ( + table, + ", ".join(k for k in keys[0]), + ) - txn.execute_batch(sql, vals) + txn.execute_values(sql, vals, fetch=False) + else: + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in keys[0]), + ", ".join("?" for _ in keys[0]), + ) + + txn.execute_batch(sql, vals) async def simple_upsert( self, @@ -1281,20 +1291,33 @@ class DatabasePool: k + "=EXCLUDED." + k for k in value_names ) - sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( - table, - ", ".join(k for k in allnames), - ", ".join("?" for _ in allnames), - ", ".join(key_names), - latter, - ) - args = [] for x, y in zip(key_values, value_values): args.append(tuple(x) + tuple(y)) - return txn.execute_batch(sql, args) + if isinstance(txn.database_engine, PostgresEngine): + # We use `execute_values` as it can be a lot faster than `execute_batch`, + # but it's only available on postgres. + sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join(key_names), + latter, + ) + + txn.execute_values(sql, args, fetch=False) + + else: + sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % ( + table, + ", ".join(k for k in allnames), + ", ".join("?" for _ in allnames), + ", ".join(key_names), + latter, + ) + + return txn.execute_batch(sql, args) @overload async def simple_select_one( diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 86075bc55b..6daf8b8ffb 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -75,8 +75,6 @@ class DirectoryWorkerStore(SQLBaseStore): desc="get_aliases_for_room", ) - -class DirectoryStore(DirectoryWorkerStore): async def create_room_alias_association( self, room_alias: RoomAlias, @@ -126,6 +124,8 @@ class DirectoryStore(DirectoryWorkerStore): 409, "Room alias %s already exists" % room_alias.to_string() ) + +class DirectoryStore(DirectoryWorkerStore): async def delete_room_alias(self, room_alias: RoomAlias) -> str: room_id = await self.db_pool.runInteraction( "delete_room_alias", self._delete_room_alias_txn, room_alias diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 40b53274fb..f07e288056 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -575,7 +575,13 @@ class PersistEventsStore: missing_auth_chains.clear() - for auth_id, event_type, state_key, chain_id, sequence_number in txn: + for ( + auth_id, + event_type, + state_key, + chain_id, + sequence_number, + ) in txn.fetchall(): event_to_types[auth_id] = (event_type, state_key) if chain_id is None: @@ -1379,18 +1385,18 @@ class PersistEventsStore: # If we're persisting an unredacted event we go and ensure # that we mark any redactions that reference this event as # requiring censoring. - sql = "UPDATE redactions SET have_censored = ? WHERE redacts = ?" - txn.execute_batch( - sql, - ( - ( - False, - event.event_id, - ) - for event, _ in events_and_contexts - if not event.internal_metadata.is_redacted() - ), + unredacted_events = [ + event.event_id + for event, _ in events_and_contexts + if not event.internal_metadata.is_redacted() + ] + sql = "UPDATE redactions SET have_censored = ? WHERE " + clause, args = make_in_list_sql_clause( + self.database_engine, + "redacts", + unredacted_events, ) + txn.execute(sql + clause, [False] + args) state_events_and_contexts = [ ec for ec in events_and_contexts if ec[0].is_state() @@ -1770,10 +1776,21 @@ class PersistEventsStore: # Not a insertion event return - # Skip processing a insertion event if the room version doesn't - # support it. + # Skip processing an insertion event if the room version doesn't + # support it or the event is not from the room creator. room_version = self.store.get_room_version_txn(txn, event.room_id) - if not room_version.msc2716_historical: + room_creator = self.db_pool.simple_select_one_onecol_txn( + txn, + table="rooms", + keyvalues={"room_id": event.room_id}, + retcol="creator", + allow_none=True, + ) + if ( + not room_version.msc2716_historical + or not self.hs.config.experimental.msc2716_enabled + or event.sender != room_creator + ): return next_chunk_id = event.content.get(EventContentFields.MSC2716_NEXT_CHUNK_ID) @@ -1822,9 +1839,20 @@ class PersistEventsStore: return # Skip processing a chunk event if the room version doesn't - # support it. + # support it or the event is not from the room creator. room_version = self.store.get_room_version_txn(txn, event.room_id) - if not room_version.msc2716_historical: + room_creator = self.db_pool.simple_select_one_onecol_txn( + txn, + table="rooms", + keyvalues={"room_id": event.room_id}, + retcol="creator", + allow_none=True, + ) + if ( + not room_version.msc2716_historical + or not self.hs.config.experimental.msc2716_enabled + or event.sender != room_creator + ): return chunk_id = event.content.get(EventContentFields.MSC2716_CHUNK_ID) diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 1388771c40..12cf6995eb 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -29,7 +29,26 @@ if TYPE_CHECKING: from synapse.server import HomeServer -class PresenceStore(SQLBaseStore): +class PresenceBackgroundUpdateStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: Connection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + # Used by `PresenceStore._get_active_presence()` + self.db_pool.updates.register_background_index_update( + "presence_stream_not_offline_index", + index_name="presence_stream_state_not_offline_idx", + table="presence_stream", + columns=["state"], + where_clause="state != 'offline'", + ) + + +class PresenceStore(PresenceBackgroundUpdateStore): def __init__( self, database: DatabasePool, @@ -332,6 +351,8 @@ class PresenceStore(SQLBaseStore): the appropriate time outs. """ + # The `presence_stream_state_not_offline_idx` index should be used for this + # query. sql = ( "SELECT user_id, state, last_active_ts, last_federation_update_ts," " last_user_sync_ts, status_msg, currently_active FROM presence_stream" diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index f98b892598..6e7312266d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -19,9 +19,10 @@ from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from synapse.api.constants import EventTypes, JoinRules +from synapse.api.constants import EventContentFields, EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.search import SearchStore @@ -1013,6 +1014,7 @@ class _BackgroundUpdates: ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column" POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2" REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth" + POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column" _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( @@ -1054,6 +1056,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_replace_room_depth_min_depth, ) + self.db_pool.updates.register_background_update_handler( + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + self._background_populate_rooms_creator_column, + ) + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. @@ -1273,7 +1280,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): keyvalues={"room_id": room_id}, retcol="MAX(stream_ordering)", allow_none=True, - desc="upsert_room_on_join", + desc="has_auth_chain_index_fallback", ) return max_ordering is None @@ -1343,6 +1350,65 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return 0 + async def _background_populate_rooms_creator_column( + self, progress: dict, batch_size: int + ): + """Background update to go and add creator information to `rooms` + table from `current_state_events` table. + """ + + last_room_id = progress.get("room_id", "") + + def _background_populate_rooms_creator_column_txn(txn: LoggingTransaction): + sql = """ + SELECT room_id, json FROM event_json + INNER JOIN rooms AS room USING (room_id) + INNER JOIN current_state_events AS state_event USING (room_id, event_id) + WHERE room_id > ? AND (room.creator IS NULL OR room.creator = '') AND state_event.type = 'm.room.create' AND state_event.state_key = '' + ORDER BY room_id + LIMIT ? + """ + + txn.execute(sql, (last_room_id, batch_size)) + room_id_to_create_event_results = txn.fetchall() + + new_last_room_id = "" + for room_id, event_json in room_id_to_create_event_results: + event_dict = db_to_json(event_json) + + creator = event_dict.get("content").get(EventContentFields.ROOM_CREATOR) + + self.db_pool.simple_update_txn( + txn, + table="rooms", + keyvalues={"room_id": room_id}, + updatevalues={"creator": creator}, + ) + new_last_room_id = room_id + + if new_last_room_id == "": + return True + + self.db_pool.updates._background_update_progress_txn( + txn, + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN, + {"room_id": new_last_room_id}, + ) + + return False + + end = await self.db_pool.runInteraction( + "_background_populate_rooms_creator_column", + _background_populate_rooms_creator_column_txn, + ) + + if end: + await self.db_pool.updates._end_background_update( + _BackgroundUpdates.POPULATE_ROOMS_CREATOR_COLUMN + ) + + return batch_size + class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def __init__(self, database: DatabasePool, db_conn, hs): @@ -1350,7 +1416,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): self.config = hs.config - async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion): + async def upsert_room_on_join( + self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase] + ): """Ensure that the room is stored in the table Called when we join a room over federation, and overwrites any room version @@ -1361,6 +1429,24 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): # mark the room as having an auth chain cover index. has_auth_chain_index = await self.has_auth_chain_index(room_id) + create_event = None + for e in auth_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + if create_event is None: + # If the state doesn't have a create event then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No create event in state") + + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + + if not isinstance(room_creator, str): + # If the create event does not have a creator then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No creator defined on the create event") + await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", @@ -1368,7 +1454,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): values={"room_version": room_version.identifier}, insertion_values={ "is_public": False, - "creator": "", + "creator": room_creator, "has_auth_chain_index": has_auth_chain_index, }, # rooms has a unique constraint on room_id, so no need to lock when doing an @@ -1396,6 +1482,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): insertion_values={ "room_version": room_version.identifier, "is_public": False, + # We don't worry about setting the `creator` here because + # we don't process any messages in a room while a user is + # invited (only after the join). "creator": "", "has_auth_chain_index": has_auth_chain_index, }, diff --git a/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql new file mode 100644 index 0000000000..f7c0b31261 --- /dev/null +++ b/synapse/storage/schema/main/delta/63/02populate-rooms-creator.sql @@ -0,0 +1,17 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) + VALUES (6302, 'populate_rooms_creator_column', '{}'); diff --git a/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql new file mode 100644 index 0000000000..b90856004b --- /dev/null +++ b/synapse/storage/schema/main/delta/63/04add_presence_stream_not_offline_index.sql @@ -0,0 +1,18 @@ +/* + * Copyright 2021 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (6304, 'presence_stream_not_offline_index', '{}'); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index c768fdea56..6f7cbe40f4 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -19,6 +19,7 @@ from contextlib import contextmanager from typing import Dict, Iterable, List, Optional, Set, Tuple, Union import attr +from sortedcontainers import SortedSet from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction @@ -240,7 +241,7 @@ class MultiWriterIdGenerator: # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). - self._unfinished_ids: Set[int] = set() + self._unfinished_ids: SortedSet[int] = SortedSet() # Set of local IDs that we've processed that are larger than the current # position, due to there being smaller unpersisted IDs. @@ -473,7 +474,7 @@ class MultiWriterIdGenerator: finished = set() - min_unfinshed = min(self._unfinished_ids) + min_unfinshed = self._unfinished_ids[0] for s in self._finished_ids: if s < min_unfinshed: if new_cur is None or new_cur < s: diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py index 522daa323d..cfb5b94ca9 100644 --- a/synapse/util/manhole.py +++ b/synapse/util/manhole.py @@ -61,7 +61,7 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs= -----END RSA PRIVATE KEY-----""" -def manhole(username, password, globals): +def manhole(settings, globals): """Starts a ssh listener with password authentication using the given username and password. Clients connecting to the ssh listener will find themselves in a colored python shell with @@ -75,6 +75,15 @@ def manhole(username, password, globals): Returns: twisted.internet.protocol.Factory: A factory to pass to ``listenTCP`` """ + username = settings.username + password = settings.password + priv_key = settings.priv_key + if priv_key is None: + priv_key = Key.fromString(PRIVATE_KEY) + pub_key = settings.pub_key + if pub_key is None: + pub_key = Key.fromString(PUBLIC_KEY) + if not isinstance(password, bytes): password = password.encode("ascii") @@ -86,8 +95,8 @@ def manhole(username, password, globals): ) factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) - factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY) - factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY) + factory.privateKeys[b"ssh-rsa"] = priv_key + factory.publicKeys[b"ssh-rsa"] = pub_key return factory |