summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py17
-rw-r--r--synapse/api/errors.py3
-rw-r--r--synapse/api/room_versions.py27
-rw-r--r--synapse/app/_base.py3
-rw-r--r--synapse/app/generic_worker.py2
-rw-r--r--synapse/app/phone_stats_home.py9
-rw-r--r--synapse/appservice/api.py11
-rw-r--r--synapse/config/_base.py45
-rw-r--r--synapse/config/_base.pyi6
-rw-r--r--synapse/config/account_validity.py149
-rw-r--r--synapse/config/api.py23
-rw-r--r--synapse/config/captcha.py4
-rw-r--r--synapse/config/cas.py12
-rw-r--r--synapse/config/consent_config.py2
-rw-r--r--synapse/config/emailconfig.py2
-rw-r--r--synapse/config/experimental.py7
-rw-r--r--synapse/config/homeserver.py3
-rw-r--r--synapse/config/oidc_config.py3
-rw-r--r--synapse/config/ratelimiting.py41
-rw-r--r--synapse/config/registration.py258
-rw-r--r--synapse/config/repository.py30
-rw-r--r--synapse/config/server.py74
-rw-r--r--synapse/config/user_directory.py18
-rw-r--r--synapse/crypto/context_factory.py13
-rw-r--r--synapse/event_auth.py24
-rw-r--r--synapse/events/spamcheck.py70
-rw-r--r--synapse/events/utils.py19
-rw-r--r--synapse/federation/federation_client.py62
-rw-r--r--synapse/federation/federation_server.py71
-rw-r--r--synapse/federation/sender/__init__.py50
-rw-r--r--synapse/federation/transport/client.py69
-rw-r--r--synapse/federation/transport/server.py88
-rw-r--r--synapse/handlers/account_validity.py127
-rw-r--r--synapse/handlers/acme.py12
-rw-r--r--synapse/handlers/acme_issuing_service.py27
-rw-r--r--synapse/handlers/auth.py10
-rw-r--r--synapse/handlers/cas_handler.py6
-rw-r--r--synapse/handlers/deactivate_account.py5
-rw-r--r--synapse/handlers/device.py12
-rw-r--r--synapse/handlers/e2e_keys.py223
-rw-r--r--synapse/handlers/e2e_room_keys.py91
-rw-r--r--synapse/handlers/federation.py227
-rw-r--r--synapse/handlers/groups_local.py83
-rw-r--r--synapse/handlers/identity.py300
-rw-r--r--synapse/handlers/message.py59
-rw-r--r--synapse/handlers/profile.py213
-rw-r--r--synapse/handlers/register.py124
-rw-r--r--synapse/handlers/room.py32
-rw-r--r--synapse/handlers/room_list.py1
-rw-r--r--synapse/handlers/room_member.py258
-rw-r--r--synapse/handlers/room_member_worker.py50
-rw-r--r--synapse/handlers/search.py38
-rw-r--r--synapse/handlers/set_password.py13
-rw-r--r--synapse/handlers/state_deltas.py14
-rw-r--r--synapse/handlers/stats.py45
-rw-r--r--synapse/handlers/sync.py83
-rw-r--r--synapse/handlers/typing.py69
-rw-r--r--synapse/handlers/user_directory.py9
-rw-r--r--synapse/http/client.py10
-rw-r--r--synapse/http/connectproxyclient.py96
-rw-r--r--synapse/http/proxyagent.py117
-rw-r--r--synapse/http/servlet.py67
-rw-r--r--synapse/logging/opentracing.py2
-rw-r--r--synapse/push/baserules.py12
-rw-r--r--synapse/push/mailer.py18
-rw-r--r--synapse/push/presentable_names.py26
-rw-r--r--synapse/push/pusherpool.py6
-rw-r--r--synapse/python_dependencies.py2
-rw-r--r--synapse/replication/http/membership.py131
-rw-r--r--synapse/replication/tcp/external_cache.py105
-rw-r--r--synapse/replication/tcp/handler.py21
-rw-r--r--synapse/replication/tcp/redis.py143
-rw-r--r--synapse/res/templates/account_previously_renewed.html1
-rw-r--r--synapse/res/templates/account_renewed.html2
-rw-r--r--synapse/res/templates/sso_auth_bad_user.html2
-rw-r--r--synapse/res/templates/sso_error.html2
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html12
-rw-r--r--synapse/res/templates/sso_redirect_confirm.html10
-rw-r--r--synapse/rest/__init__.py3
-rw-r--r--synapse/rest/admin/__init__.py6
-rw-r--r--synapse/rest/admin/rooms.py80
-rw-r--r--synapse/rest/admin/users.py57
-rw-r--r--synapse/rest/client/v1/presence.py4
-rw-r--r--synapse/rest/client/v1/profile.py41
-rw-r--r--synapse/rest/client/v1/room.py13
-rw-r--r--synapse/rest/client/v2_alpha/account.py230
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py14
-rw-r--r--synapse/rest/client/v2_alpha/account_validity.py30
-rw-r--r--synapse/rest/client/v2_alpha/knock.py103
-rw-r--r--synapse/rest/client/v2_alpha/register.py254
-rw-r--r--synapse/rest/client/v2_alpha/sync.py81
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py142
-rw-r--r--synapse/rest/client/versions.py5
-rw-r--r--synapse/rest/media/v1/_base.py3
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py47
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py236
-rw-r--r--synapse/rulecheck/__init__.py0
-rw-r--r--synapse/rulecheck/domain_rule_checker.py181
-rw-r--r--synapse/server.py40
-rw-r--r--synapse/state/__init__.py11
-rw-r--r--synapse/storage/database.py6
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py4
-rw-r--r--synapse/storage/databases/main/event_push_actions.py4
-rw-r--r--synapse/storage/databases/main/events.py199
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py12
-rw-r--r--synapse/storage/databases/main/events_forward_extremities.py101
-rw-r--r--synapse/storage/databases/main/events_worker.py2
-rw-r--r--synapse/storage/databases/main/media_repository.py10
-rw-r--r--synapse/storage/databases/main/metrics.py56
-rw-r--r--synapse/storage/databases/main/profile.py151
-rw-r--r--synapse/storage/databases/main/purge_events.py2
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/registration.py164
-rw-r--r--synapse/storage/databases/main/room.py37
-rw-r--r--synapse/storage/databases/main/roommember.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/48/profiles_batch.sql36
-rw-r--r--synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql23
-rw-r--r--synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql16
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql17
-rw-r--r--synapse/storage/databases/main/schema/delta/59/01ignored_user.py2
-rw-r--r--synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres15
-rw-r--r--synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite4
-rw-r--r--synapse/storage/databases/main/search.py7
-rw-r--r--synapse/storage/databases/main/stats.py23
-rw-r--r--synapse/storage/databases/main/user_directory.py61
-rw-r--r--synapse/storage/databases/state/store.py4
-rw-r--r--synapse/storage/util/id_generators.py21
-rw-r--r--synapse/storage/util/sequence.py16
-rw-r--r--synapse/third_party_rules/__init__.py14
-rw-r--r--synapse/third_party_rules/access_rules.py971
-rw-r--r--synapse/types.py14
-rw-r--r--synapse/util/module_loader.py3
-rw-r--r--synapse/util/threepids.py31
136 files changed, 6587 insertions, 1187 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py

index 67ecbd32ff..b575d85976 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -79,7 +79,7 @@ class Auth: self._auth_blocking = AuthBlocking(self.hs) - self._account_validity = hs.config.account_validity + self._account_validity_enabled = hs.config.account_validity_enabled self._track_appservice_user_ips = hs.config.track_appservice_user_ips self._macaroon_secret_key = hs.config.macaroon_secret_key @@ -192,7 +192,7 @@ class Auth: access_token = self.get_access_token_from_request(request) - user_id, app_service = await self._get_appservice_user_id(request) + user_id, app_service = self._get_appservice_user_id(request) if user_id: if ip_addr and self._track_appservice_user_ips: await self.store.insert_client_ip( @@ -222,7 +222,7 @@ class Auth: shadow_banned = user_info.shadow_banned # Deny the request if the user account has expired. - if self._account_validity.enabled and not allow_expired: + if self._account_validity_enabled and not allow_expired: if await self.store.is_account_expired( user_info.user_id, self.clock.time_msec() ): @@ -268,10 +268,11 @@ class Auth: except KeyError: raise MissingClientTokenError() - async def _get_appservice_user_id(self, request): + def _get_appservice_user_id(self, request): app_service = self.store.get_app_service_by_token( self.get_access_token_from_request(request) ) + if app_service is None: return None, None @@ -289,8 +290,12 @@ class Auth: if not app_service.is_interested_in_user(user_id): raise AuthError(403, "Application service cannot masquerade as this user.") - if not (await self.store.get_user_by_id(user_id)): - raise AuthError(403, "Application service has not registered this user") + # Let ASes manipulate nonexistent users (e.g. to shadow-register them) + # if not (yield self.store.get_user_by_id(user_id)): + # raise AuthError( + # 403, + # "Application service has not registered this user" + # ) return user_id, app_service async def get_user_by_access_token( diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cd6670d0a2..90bb01f414 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index de2cc15d33..139fbf5524 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py
@@ -57,7 +57,7 @@ class RoomVersion: state_res = attr.ib(type=int) # one of the StateResolutionVersions enforce_key_validity = attr.ib(type=bool) - # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules + # Before MSC2432, m.room.aliases had special auth rules and redaction rules special_case_aliases_auth = attr.ib(type=bool) # Strictly enforce canonicaljson, do not allow: # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] @@ -69,6 +69,11 @@ class RoomVersion: limit_notifications_power_levels = attr.ib(type=bool) # MSC2174/MSC2176: Apply updated redaction rules algorithm. msc2176_redaction_rules = attr.ib(type=bool) + # MSC2174/MSC2176: Apply updated redaction rules algorithm. + msc2176_redaction_rules = attr.ib(type=bool) + # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending + # m.room.membership event with membership 'knock'. + allow_knocking = attr.ib(type=bool) class RoomVersions: @@ -82,6 +87,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + allow_knocking=False, ) V2 = RoomVersion( "2", @@ -93,6 +99,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + allow_knocking=False, ) V3 = RoomVersion( "3", @@ -104,6 +111,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + allow_knocking=False, ) V4 = RoomVersion( "4", @@ -115,6 +123,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + allow_knocking=False, ) V5 = RoomVersion( "5", @@ -126,6 +135,7 @@ class RoomVersions: strict_canonicaljson=False, limit_notifications_power_levels=False, msc2176_redaction_rules=False, + allow_knocking=False, ) V6 = RoomVersion( "6", @@ -137,6 +147,19 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=False, + allow_knocking=False, + ) + V7 = RoomVersion( + "7", + RoomDisposition.UNSTABLE, + EventFormatVersions.V3, + StateResolutionVersions.V2, + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + msc2176_redaction_rules=False, + allow_knocking=True, ) MSC2176 = RoomVersion( "org.matrix.msc2176", @@ -148,6 +171,7 @@ class RoomVersions: strict_canonicaljson=True, limit_notifications_power_levels=True, msc2176_redaction_rules=True, + allow_knocking=False, ) @@ -160,6 +184,7 @@ KNOWN_ROOM_VERSIONS = { RoomVersions.V4, RoomVersions.V5, RoomVersions.V6, + RoomVersions.V7, RoomVersions.MSC2176, ) } # type: Dict[str, RoomVersion] diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 395e202b89..9840a9d55b 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py
@@ -16,6 +16,7 @@ import gc import logging import os +import platform import signal import socket import sys @@ -339,7 +340,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon # rest of time. Doing so means less work each GC (hopefully). # # This only works on Python 3.7 - if sys.version_info >= (3, 7): + if platform.python_implementation() == "CPython" and sys.version_info >= (3, 7): gc.collect() gc.freeze() diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 516f2464b4..e363d681fd 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py
@@ -165,7 +165,7 @@ class PresenceStatusStubServlet(RestServlet): async def on_GET(self, request, user_id): await self.auth.get_user_by_req(request) - return 200, {"presence": "offline"} + return 200, {"presence": "offline", "user_id": user_id} async def on_PUT(self, request, user_id): await self.auth.get_user_by_req(request) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index c38cf8231f..8f86cecb76 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py
@@ -93,15 +93,20 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): stats["daily_active_users"] = await hs.get_datastore().count_daily_users() stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users() + daily_active_e2ee_rooms = await hs.get_datastore().count_daily_active_e2ee_rooms() + stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms + stats["daily_e2ee_messages"] = await hs.get_datastore().count_daily_e2ee_messages() + daily_sent_e2ee_messages = await hs.get_datastore().count_daily_sent_e2ee_messages() + stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms() stats["daily_messages"] = await hs.get_datastore().count_daily_messages() + daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages r30_results = await hs.get_datastore().count_r30_users() for name, count in r30_results.items(): stats["r30_users_" + name] = count - daily_sent_messages = await hs.get_datastore().count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages stats["cache_factor"] = hs.config.caches.global_factor stats["event_cache_size"] = hs.config.caches.event_cache_size diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e366a982b8..5dcd9ea290 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple from prometheus_client import Counter -from synapse.api.constants import EventTypes, ThirdPartyEntityKind +from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind from synapse.api.errors import CodeMessageException from synapse.events import EventBase from synapse.events.utils import serialize_event @@ -249,9 +249,14 @@ class ApplicationServiceApi(SimpleHttpClient): e, time_now, as_client_event=True, - is_invite=( + # If this is an invite or a knock membership event, and we're interested + # in this user, then include any stripped state alongside the event. + include_stripped_room_state=( e.type == EventTypes.Member - and e.membership == "invite" + and ( + e.membership == Membership.INVITE + or e.membership == Membership.KNOCK + ) and service.is_interested_in_user(e.state_key) ), ) diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 35e5594b73..c629990dd4 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -20,6 +20,7 @@ import errno import os from collections import OrderedDict from hashlib import sha256 +from io import open as io_open from textwrap import dedent from typing import Any, Iterable, List, MutableMapping, Optional @@ -200,14 +201,31 @@ class Config: @classmethod def read_file(cls, file_path, config_name): cls.check_file(file_path, config_name) - with open(file_path) as file_stream: + with io_open(file_path, encoding="utf-8") as file_stream: return file_stream.read() + def read_template(self, filename: str) -> jinja2.Template: + """Load a template file from disk. + + This function will attempt to load the given template from the default Synapse + template directory. + + Files read are treated as Jinja templates. The templates is not rendered yet + and has autoescape enabled. + + Args: + filename: A template filename to read. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A jinja2 template. + """ + return self.read_templates([filename])[0] + def read_templates( - self, - filenames: List[str], - custom_template_directory: Optional[str] = None, - autoescape: bool = False, + self, filenames: List[str], custom_template_directory: Optional[str] = None, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. @@ -215,7 +233,8 @@ class Config: template directory. If `custom_template_directory` is supplied, that directory is tried first. - Files read are treated as Jinja templates. These templates are not rendered yet. + Files read are treated as Jinja templates. The templates are not rendered yet + and have autoescape enabled. Args: filenames: A list of template filenames to read. @@ -223,16 +242,12 @@ class Config: custom_template_directory: A directory to try to look for the templates before using the default Synapse template directory instead. - autoescape: Whether to autoescape variables before inserting them into the - template. - Raises: ConfigError: if the file's path is incorrect or otherwise cannot be read. Returns: A list of jinja2 templates. """ - templates = [] search_directories = [self.default_template_dir] # The loader will first look in the custom template directory (if specified) for the @@ -250,7 +265,7 @@ class Config: # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) - env = jinja2.Environment(loader=loader, autoescape=autoescape) + env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),) # Update the environment with our custom filters env.filters.update( @@ -260,12 +275,8 @@ class Config: } ) - for filename in filenames: - # Load the template - template = env.get_template(filename) - templates.append(template) - - return templates + # Load the templates + return [env.get_template(filename) for filename in filenames] class RootConfig: diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 3ccea4b02d..0565418e60 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi
@@ -1,6 +1,7 @@ from typing import Any, Iterable, List, Optional from synapse.config import ( + account_validity, api, appservice, auth, @@ -19,6 +20,7 @@ from synapse.config import ( password_auth_providers, push, ratelimiting, + redis, registration, repository, room_directory, @@ -53,11 +55,12 @@ class RootConfig: tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig - ratelimit: ratelimiting.RatelimitConfig + ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig registration: registration.RegistrationConfig + account_validity: account_validity.AccountValidityConfig metrics: metrics.MetricsConfig api: api.ApiConfig appservice: appservice.AppServiceConfig @@ -81,6 +84,7 @@ class RootConfig: roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig tracer: tracer.TracerConfig + redis: redis.RedisConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py new file mode 100644
index 0000000000..6d107944a3 --- /dev/null +++ b/synapse/config/account_validity.py
@@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.config._base import Config, ConfigError + + +class AccountValidityConfig(Config): + section = "account_validity" + + def read_config(self, config, **kwargs): + account_validity_config = config.get("account_validity") or {} + self.account_validity_enabled = account_validity_config.get("enabled", False) + self.account_validity_renew_by_email_enabled = ( + "renew_at" in account_validity_config + ) + + if self.account_validity_enabled: + if "period" in account_validity_config: + self.account_validity_period = self.parse_duration( + account_validity_config["period"] + ) + else: + raise ConfigError("'period' is required when using account validity") + + if "renew_at" in account_validity_config: + self.account_validity_renew_at = self.parse_duration( + account_validity_config["renew_at"] + ) + + if "renew_email_subject" in account_validity_config: + self.account_validity_renew_email_subject = account_validity_config[ + "renew_email_subject" + ] + else: + self.account_validity_renew_email_subject = "Renew your %(app)s account" + + self.account_validity_startup_job_max_delta = ( + self.account_validity_period * 10.0 / 100.0 + ) + + if self.account_validity_renew_by_email_enabled: + if not self.public_baseurl: + raise ConfigError("Can't send renewal emails without 'public_baseurl'") + + # Load account validity templates. + # We do this here instead of in AccountValidityConfig as read_templates + # relies on state that hasn't been initialised in AccountValidityConfig + account_renewed_template_filename = account_validity_config.get( + "account_renewed_html_path", "account_renewed.html" + ) + account_previously_renewed_template_filename = account_validity_config.get( + "account_previously_renewed_html_path", "account_previously_renewed.html" + ) + invalid_token_template_filename = account_validity_config.get( + "invalid_token_html_path", "invalid_token.html" + ) + custom_template_directory = account_validity_config.get("template_dir") + + ( + self.account_validity_account_renewed_template, + self.account_validity_account_previously_renewed_template, + self.account_validity_invalid_token_template, + ) = self.read_templates( + [ + account_renewed_template_filename, + account_previously_renewed_template_filename, + invalid_token_template_filename, + ], + custom_template_directory=custom_template_directory, + ) + + def generate_config_section(self, **kwargs): + return """\ + ## Account Validity ## + # + # Optional account validity configuration. This allows for accounts to be denied + # any request after a given period. + # + # Once this feature is enabled, Synapse will look for registered users without an + # expiration date at startup and will add one to every account it found using the + # current settings at that time. + # This means that, if a validity period is set, and Synapse is restarted (it will + # then derive an expiration date from the current validity period), and some time + # after that the validity period changes and Synapse is restarted, the users' + # expiration dates won't be updated unless their account is manually renewed. This + # date will be randomly selected within a range [now + period - d ; now + period], + # where d is equal to 10% of the validity period. + # + account_validity: + # The account validity feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # The period after which an account is valid after its registration. When + # renewing the account, its validity period will be extended by this amount + # of time. This parameter is required when using the account validity + # feature. + # + #period: 6w + + # The amount of time before an account's expiry date at which Synapse will + # send an email to the account's email address with a renewal link. By + # default, no such emails are sent. + # + # If you enable this setting, you will also need to fill out the 'email' and + # 'public_baseurl' configuration sections. + # + #renew_at: 1w + + # The subject of the email sent out with the renewal link. '%(app)s' can be + # used as a placeholder for the 'app_name' parameter from the 'email' + # section. + # + # Note that the placeholder must be written '%(app)s', including the + # trailing 's'. + # + # If this is not set, a default value is used. + # + #renew_email_subject: "Renew your %(app)s account" + + # Directory in which Synapse will try to find templates for the HTML files to + # serve to the user when trying to renew an account. If not set, default + # templates from within the Synapse package will be used. + # + #template_dir: "res/templates" + + # File within 'template_dir' giving the HTML to be displayed to the user after + # they successfully renewed their account. If not set, default text is used. + # + #account_renewed_html_path: "account_renewed.html" + + # File within 'template_dir' giving the HTML to be displayed when the user + # tries to renew an account with an invalid renewal token. If not set, + # default text is used. + # + #invalid_token_html_path: "invalid_token.html" + """ diff --git a/synapse/config/api.py b/synapse/config/api.py
index 74cd53a8ed..0638ed8d2e 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py
@@ -1,4 +1,5 @@ # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,27 +17,31 @@ from synapse.api.constants import EventTypes from ._base import Config +# The default types of room state to send to users to are invited to or knock on a room. +DEFAULT_ROOM_STATE_TYPES = [ + EventTypes.JoinRules, + EventTypes.CanonicalAlias, + EventTypes.RoomAvatar, + EventTypes.RoomEncryption, + EventTypes.Name, +] + class ApiConfig(Config): section = "api" def read_config(self, config, **kwargs): self.room_invite_state_types = config.get( - "room_invite_state_types", - [ - EventTypes.JoinRules, - EventTypes.CanonicalAlias, - EventTypes.RoomAvatar, - EventTypes.RoomEncryption, - EventTypes.Name, - ], + "room_invite_state_types", DEFAULT_ROOM_STATE_TYPES ) def generate_config_section(cls, **kwargs): return """\ ## API Configuration ## - # A list of event types that will be included in the room_invite_state + # A list of event types from a room that will be given to users when they + # are invited to a room. This allows clients to display information about the + # room that they've been invited to, without actually being in the room yet. # #room_invite_state_types: # - "{JoinRules}" diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index cb00958165..9e48f865cc 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py
@@ -28,9 +28,7 @@ class CaptchaConfig(Config): "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", ) - self.recaptcha_template = self.read_templates( - ["recaptcha.html"], autoescape=True - )[0] + self.recaptcha_template = self.read_template("recaptcha.html") def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index c7877b4095..b226890c2a 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py
@@ -30,7 +30,13 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - self.cas_service_url = cas_config["service_url"] + public_base_url = cas_config.get("service_url") or self.public_baseurl + if public_base_url[-1] != "/": + public_base_url += "/" + # TODO Update this to a _synapse URL. + self.cas_service_url = ( + public_base_url + "_matrix/client/r0/login/cas/ticket" + ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes") or {} else: @@ -53,10 +59,6 @@ class CasConfig(Config): # #server_url: "https://cas-server.com" - # The public URL of the homeserver. - # - #service_url: "https://homeserver.domain.com:8448" - # The attribute of the CAS response to use as the display name. # # If unset, no displayname will be set. diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 6efa59b110..c47f364b14 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py
@@ -89,7 +89,7 @@ class ConsentConfig(Config): def read_config(self, config, **kwargs): consent_config = config.get("user_consent") - self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0] + self.terms_template = self.read_template("terms.html") if consent_config is None: return diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 6a487afd34..458f5eb0da 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py
@@ -291,7 +291,7 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if self.account_validity.renew_by_email_enabled: + if self.account_validity_renew_by_email_enabled: expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b1c1c51e4d..ba9d37553b 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.config._base import Config from synapse.types import JsonDict @@ -25,5 +26,11 @@ class ExperimentalConfig(Config): def read_config(self, config: JsonDict, **kwargs): experimental = config.get("experimental_features") or {} + # MSC2403 (room knocking) + self.msc2403_enabled = experimental.get("msc2403_enabled", False) # type: bool + if self.msc2403_enabled: + # Enable the MSC2403 unstable room version + KNOWN_ROOM_VERSIONS.update({RoomVersions.V7.identifier: RoomVersions.V7}) + # MSC2858 (multiple SSO identity providers) self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 64a2429f77..58961679ff 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py
@@ -13,8 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from ._base import RootConfig +from .account_validity import AccountValidityConfig from .api import ApiConfig from .appservice import AppServiceConfig from .auth import AuthConfig @@ -69,6 +69,7 @@ class HomeServerConfig(RootConfig): CaptchaConfig, VoipConfig, RegistrationConfig, + AccountValidityConfig, MetricsConfig, ApiConfig, AppServiceConfig, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 784b416f95..bb122ef182 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py
@@ -53,8 +53,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl - self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" + self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..070eb1b761 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig: defaults={"per_second": 0.17, "burst_count": 3.0}, ): self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = config.get("burst_count", defaults["burst_count"]) + self.burst_count = int(config.get("burst_count", defaults["burst_count"])) class FederationRateLimitConfig: @@ -76,6 +76,9 @@ class RatelimitConfig(Config): ) self.rc_registration = RateLimitConfig(config.get("rc_registration", {})) + self.rc_third_party_invite = RateLimitConfig( + config.get("rc_third_party_invite", {}) + ) rc_login_config = config.get("rc_login", {}) self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {})) @@ -102,6 +105,20 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + self.rc_3pid_validation = RateLimitConfig( + config.get("rc_3pid_validation") or {}, + defaults={"per_second": 0.003, "burst_count": 5}, + ) + + self.rc_invites_per_room = RateLimitConfig( + config.get("rc_invites", {}).get("per_room", {}), + defaults={"per_second": 0.3, "burst_count": 10}, + ) + self.rc_invites_per_user = RateLimitConfig( + config.get("rc_invites", {}).get("per_user", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -124,6 +141,8 @@ class RatelimitConfig(Config): # - one for login that ratelimits login requests based on the account the # client is attempting to log into, based on the amount of failed login # attempts for this account. + # - one that ratelimits third-party invites requests based on the account + # that's making the requests. # - one for ratelimiting redactions by room admins. If this is not explicitly # set then it uses the same ratelimiting as per rc_message. This is useful # to allow room admins to deal with abuse quickly. @@ -131,6 +150,9 @@ class RatelimitConfig(Config): # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) + # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. + # - two for ratelimiting how often invites can be sent in a room or to a + # specific user. # # The defaults are as shown below. # @@ -153,6 +175,10 @@ class RatelimitConfig(Config): # per_second: 0.17 # burst_count: 3 # + #rc_third_party_invite: + # per_second: 0.2 + # burst_count: 10 + # #rc_admin_redaction: # per_second: 1 # burst_count: 50 @@ -164,7 +190,18 @@ class RatelimitConfig(Config): # remote: # per_second: 0.01 # burst_count: 3 - + # + #rc_3pid_validation: + # per_second: 0.003 + # burst_count: 5 + # + #rc_invites: + # per_room: + # per_second: 0.3 + # burst_count: 10 + # per_user: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 4bfc69cb7a..c96530d4e3 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -13,70 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -import pkg_resources - from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool -class AccountValidityConfig(Config): - section = "accountvalidity" - - def __init__(self, config, synapse_config): - if config is None: - return - super().__init__() - self.enabled = config.get("enabled", False) - self.renew_by_email_enabled = "renew_at" in config - - if self.enabled: - if "period" in config: - self.period = self.parse_duration(config["period"]) - else: - raise ConfigError("'period' is required when using account validity") - - if "renew_at" in config: - self.renew_at = self.parse_duration(config["renew_at"]) - - if "renew_email_subject" in config: - self.renew_email_subject = config["renew_email_subject"] - else: - self.renew_email_subject = "Renew your %(app)s account" - - self.startup_job_max_delta = self.period * 10.0 / 100.0 - - template_dir = config.get("template_dir") - - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - if "account_renewed_html_path" in config: - file_path = os.path.join(template_dir, config["account_renewed_html_path"]) - - self.account_renewed_html_content = self.read_file( - file_path, "account_validity.account_renewed_html_path" - ) - else: - self.account_renewed_html_content = ( - "<html><body>Your account has been successfully renewed.</body><html>" - ) - - if "invalid_token_html_path" in config: - file_path = os.path.join(template_dir, config["invalid_token_html_path"]) - - self.invalid_token_html_content = self.read_file( - file_path, "account_validity.invalid_token_html_path" - ) - else: - self.invalid_token_html_content = ( - "<html><body>Invalid renewal token.</body><html>" - ) - - class RegistrationConfig(Config): section = "registration" @@ -89,14 +31,21 @@ class RegistrationConfig(Config): str(config["disable_registration"]) ) - self.account_validity = AccountValidityConfig( - config.get("account_validity") or {}, config - ) - self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) + self.check_is_for_allowed_local_3pids = config.get( + "check_is_for_allowed_local_3pids", None + ) + self.allow_invited_3pids = config.get("allow_invited_3pids", False) + + self.disable_3pid_changes = config.get("disable_3pid_changes", False) + self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) self.registration_shared_secret = config.get("registration_shared_secret") + self.register_mxid_from_3pid = config.get("register_mxid_from_3pid") + self.register_just_use_email_for_display_name = config.get( + "register_just_use_email_for_display_name", False + ) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config.get( @@ -104,7 +53,21 @@ class RegistrationConfig(Config): ) account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") + if ( + self.account_threepid_delegate_email + and not self.account_threepid_delegate_email.startswith("http") + ): + raise ConfigError( + "account_threepid_delegates.email must begin with http:// or https://" + ) self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if ( + self.account_threepid_delegate_msisdn + and not self.account_threepid_delegate_msisdn.startswith("http") + ): + raise ConfigError( + "account_threepid_delegates.msisdn must begin with http:// or https://" + ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -166,6 +129,15 @@ class RegistrationConfig(Config): self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) self.enable_3pid_changes = config.get("enable_3pid_changes", True) + self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", []) + if not isinstance(self.replicate_user_profiles_to, list): + self.replicate_user_profiles_to = [self.replicate_user_profiles_to] + + self.shadow_server = config.get("shadow_server", None) + self.rewrite_identity_server_urls = ( + config.get("rewrite_identity_server_urls") or {} + ) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -176,9 +148,24 @@ class RegistrationConfig(Config): self.session_lifetime = session_lifetime # The success template used during fallback auth. - self.fallback_success_template = self.read_templates( - ["auth_success.html"], autoescape=True - )[0] + self.fallback_success_template = self.read_template("auth_success.html") + + self.bind_new_user_emails_to_sydent = config.get( + "bind_new_user_emails_to_sydent" + ) + + if self.bind_new_user_emails_to_sydent: + if not isinstance( + self.bind_new_user_emails_to_sydent, str + ) or not self.bind_new_user_emails_to_sydent.startswith("http"): + raise ConfigError( + "Option bind_new_user_emails_to_sydent has invalid value" + ) + + # Remove trailing slashes + self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.strip( + "/" + ) def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: @@ -199,70 +186,6 @@ class RegistrationConfig(Config): # #enable_registration: false - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10%% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. - # - # If you enable this setting, you will also need to fill out the 'email' - # configuration section. You should also check that 'public_baseurl' is set - # correctly. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %%(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - #template_dir: "res/templates" - - # File within 'template_dir' giving the HTML to be displayed to the user after - # they successfully renewed their account. If not set, default text is used. - # - #account_renewed_html_path: "account_renewed.html" - - # File within 'template_dir' giving the HTML to be displayed when the user - # tries to renew an account with an invalid renewal token. If not set, - # default text is used. - # - #invalid_token_html_path: "invalid_token.html" - # Time that a user's session remains valid for, after they log in. # # Note that this is not currently compatible with guest logins. @@ -285,9 +208,32 @@ class RegistrationConfig(Config): # #disable_msisdn_registration: true + # Derive the user's matrix ID from a type of 3PID used when registering. + # This overrides any matrix ID the user proposes when calling /register + # The 3PID type should be present in registrations_require_3pid to avoid + # users failing to register if they don't specify the right kind of 3pid. + # + #register_mxid_from_3pid: email + + # Uncomment to set the display name of new users to their email address, + # rather than using the default heuristic. + # + #register_just_use_email_for_display_name: true + # Mandate that users are only allowed to associate certain formats of # 3PIDs with accounts on this server. # + # Use an Identity Server to establish which 3PIDs are allowed to register? + # Overrides allowed_local_3pids below. + # + #check_is_for_allowed_local_3pids: matrix.org + # + # If you are using an IS you can also check whether that IS registers + # pending invites for the given 3PID (and then allow it to sign up on + # the platform): + # + #allow_invited_3pids: false + # #allowed_local_3pids: # - medium: email # pattern: '.*@matrix\\.org' @@ -296,6 +242,11 @@ class RegistrationConfig(Config): # - medium: msisdn # pattern: '\\+44' + # If true, stop users from trying to change the 3PIDs associated with + # their accounts. + # + #disable_3pid_changes: false + # Enable 3PIDs lookup requests to identity servers from this server. # #enable_3pid_lookup: true @@ -326,6 +277,30 @@ class RegistrationConfig(Config): # #default_identity_server: https://matrix.org + # If enabled, user IDs, display names and avatar URLs will be replicated + # to this server whenever they change. + # This is an experimental API currently implemented by sydent to support + # cross-homeserver user directories. + # + #replicate_user_profiles_to: example.com + + # If specified, attempt to replay registrations, profile changes & 3pid + # bindings on the given target homeserver via the AS API. The HS is authed + # via a given AS token. + # + #shadow_server: + # hs_url: https://shadow.example.com + # hs: shadow.example.com + # as_token: 12u394refgbdhivsia + + # If enabled, don't let users set their own display names/avatars + # other than for the very first time (unless they are a server admin). + # Useful when provisioning users based on the contents of a 3rd party + # directory and to avoid ambiguities. + # + #disable_set_displayname: false + #disable_set_avatar_url: false + # Handle threepid (email/phone etc) registration and password resets through a set of # *trusted* identity servers. Note that this allows the configured identity server to # reset passwords for accounts! @@ -450,6 +425,31 @@ class RegistrationConfig(Config): # Defaults to true. # #auto_join_rooms_for_guests: false + + # Rewrite identity server URLs with a map from one URL to another. Applies to URLs + # provided by clients (which have https:// prepended) and those specified + # in `account_threepid_delegates`. URLs should not feature a trailing slash. + # + #rewrite_identity_server_urls: + # "https://somewhere.example.com": "https://somewhereelse.example.com" + + # When a user registers an account with an email address, it can be useful to + # bind that email address to their mxid on an identity server. Typically, this + # requires the user to validate their email address with the identity server. + # However if Synapse itself is handling email validation on registration, the + # user ends up needing to validate their email twice, which leads to poor UX. + # + # It is possible to force Sydent, one identity server implementation, to bind + # threepids using its internal, unauthenticated bind API: + # https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api + # + # Configure the address of a Sydent server here to have Synapse attempt + # to automatically bind users' emails following registration. The + # internal bind API must be reachable from Synapse, but should NOT be + # exposed to any third party, as it allows the creation of bindings + # without validation. + # + #bind_new_user_emails_to_sydent: https://example.com:8091 """ % locals() ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 850ac3ebd6..31e3f7148b 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -107,6 +107,12 @@ class ContentRepositoryConfig(Config): self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) + self.max_avatar_size = config.get("max_avatar_size") + if self.max_avatar_size: + self.max_avatar_size = self.parse_size(self.max_avatar_size) + + self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", []) + self.media_store_path = self.ensure_directory( config.get("media_store_path", "media_store") ) @@ -250,6 +256,30 @@ class ContentRepositoryConfig(Config): # #max_upload_size: 50M + # The largest allowed size for a user avatar. If not defined, no + # restriction will be imposed. + # + # Note that this only applies when an avatar is changed globally. + # Per-room avatar changes are not affected. See allow_per_room_profiles + # for disabling that functionality. + # + # Note that user avatar changes will not work if this is set without + # using Synapse's local media repo. + # + #max_avatar_size: 10M + + # Allow mimetypes for a user avatar. If not defined, no restriction will + # be imposed. + # + # Note that this only applies when an avatar is changed globally. + # Per-room avatar changes are not affected. See allow_per_room_profiles + # for disabling that functionality. + # + # Note that user avatar changes will not work if this is set without + # using Synapse's local media repo. + # + #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"] + # Maximum number of pixels that will be thumbnailed # #max_image_pixels: 32M diff --git a/synapse/config/server.py b/synapse/config/server.py
index 47a0370173..b76afce5e5 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -338,6 +338,12 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + # Whether to show the users on this homeserver in the user directory. Defaults to + # True. + self.show_users_in_user_directory = config.get( + "show_users_in_user_directory", True + ) + retention_config = config.get("retention") if retention_config is None: retention_config = {} @@ -1047,6 +1053,74 @@ class ServerConfig(Config): # #allow_per_room_profiles: false + # Whether to show the users on this homeserver in the user directory. Defaults to + # 'true'. + # + #show_users_in_user_directory: false + + # Message retention policy at the server level. + # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events which lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # which 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but not want + # that purge to be performed by a job that's iterating over every room it knows, + # which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m: + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h + # How long to keep redacted events in unredacted form in the database. After # this period redacted events get replaced with their redacted form in the DB. # diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..306e0cc8a4 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py
@@ -26,6 +26,8 @@ class UserDirectoryConfig(Config): def read_config(self, config, **kwargs): self.user_directory_search_enabled = True self.user_directory_search_all_users = False + self.user_directory_defer_to_id_server = None + self.user_directory_search_prefer_local_users = False user_directory_config = config.get("user_directory", None) if user_directory_config: self.user_directory_search_enabled = user_directory_config.get( @@ -34,6 +36,12 @@ class UserDirectoryConfig(Config): self.user_directory_search_all_users = user_directory_config.get( "search_all_users", False ) + self.user_directory_defer_to_id_server = user_directory_config.get( + "defer_to_id_server", None + ) + self.user_directory_search_prefer_local_users = user_directory_config.get( + "prefer_local_users", False + ) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ @@ -49,7 +57,17 @@ class UserDirectoryConfig(Config): # rebuild the user_directory search indexes, see # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md # + # 'prefer_local_users' defines whether to prioritise local users in + # search query results. If True, local users are more likely to appear above + # remote users when searching the user directory. Defaults to false. + # #user_directory: # enabled: true # search_all_users: false + # prefer_local_users: false + # + # # If this is set, user search will be delegated to this ID server instead + # # of synapse performing the search itself. + # # This is an experimental API. + # defer_to_id_server: https://id.example.com """ diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 74b67b230a..14b21796d9 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py
@@ -125,19 +125,24 @@ class FederationPolicyForHTTPS: self._no_verify_ssl_context = _no_verify_ssl.getContext() self._no_verify_ssl_context.set_info_callback(_context_info_cb) - def get_options(self, host: bytes): + self._should_verify = self._config.federation_verify_certificates + + self._federation_certificate_verification_whitelist = ( + self._config.federation_certificate_verification_whitelist + ) + def get_options(self, host: bytes): # IPolicyForHTTPS.get_options takes bytes, but we want to compare # against the str whitelist. The hostnames in the whitelist are already # IDNA-encoded like the hosts will be here. ascii_host = host.decode("ascii") # Check if certificate verification has been enabled - should_verify = self._config.federation_verify_certificates + should_verify = self._should_verify # Check if we've disabled certificate verification for this host - if should_verify: - for regex in self._config.federation_certificate_verification_whitelist: + if self._should_verify: + for regex in self._federation_certificate_verification_whitelist: if regex.match(ascii_host): should_verify = False break diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 56f8dc9caf..498a699290 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py
@@ -161,6 +161,7 @@ def check( if logger.isEnabledFor(logging.DEBUG): logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) + # 5. If type is m.room.membership if event.type == EventTypes.Member: _is_membership_change_allowed(event, auth_events) logger.debug("Allowing! %s", event) @@ -247,6 +248,7 @@ def _is_membership_change_allowed( caller_in_room = caller and caller.membership == Membership.JOIN caller_invited = caller and caller.membership == Membership.INVITE + caller_knocked = caller and caller.membership == Membership.KNOCK # get info about the target key = (EventTypes.Member, target_user_id) @@ -289,9 +291,12 @@ def _is_membership_change_allowed( raise AuthError(403, "%s is banned from the room" % (target_user_id,)) return - if Membership.JOIN != membership: + # Require the user to be in the room for membership changes other than join/knock. + if Membership.JOIN != membership and Membership.KNOCK != membership: + # If the user has been invited or has knocked, they are allowed to change their + # membership event to leave if ( - caller_invited + (caller_invited or caller_knocked) and Membership.LEAVE == membership and target_user_id == event.user_id ): @@ -324,7 +329,7 @@ def _is_membership_change_allowed( raise AuthError(403, "You are banned from this room") elif join_rule == JoinRules.PUBLIC: pass - elif join_rule == JoinRules.INVITE: + elif join_rule in (JoinRules.INVITE, JoinRules.KNOCK): if not caller_in_room and not caller_invited: raise AuthError(403, "You are not invited to this room.") else: @@ -343,6 +348,17 @@ def _is_membership_change_allowed( elif Membership.BAN == membership: if user_level < ban_level or user_level <= target_level: raise AuthError(403, "You don't have permission to ban") + elif Membership.KNOCK == membership: + if join_rule != JoinRules.KNOCK: + raise AuthError(403, "You don't have permission to knock") + elif target_user_id != event.user_id: + raise AuthError(403, "You cannot knock for other users") + elif target_in_room: + raise AuthError(403, "You cannot knock on a room you are already in") + elif caller_invited: + raise AuthError(403, "You are already invited to this room") + elif target_banned: + raise AuthError(403, "You are banned from this room") else: raise AuthError(500, "Unknown membership %s" % membership) @@ -699,7 +715,7 @@ def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]: if event.type == EventTypes.Member: membership = event.content["membership"] - if membership in [Membership.JOIN, Membership.INVITE]: + if membership in [Membership.JOIN, Membership.INVITE, Membership.KNOCK]: auth_types.add((EventTypes.JoinRules, "")) auth_types.add((EventTypes.Member, event.state_key)) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index e7e3a7b9a4..f3322af499 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py
@@ -63,16 +63,32 @@ class SpamChecker: return False async def user_may_invite( - self, inviter_userid: str, invitee_userid: str, room_id: str + self, + inviter_userid: str, + invitee_userid: Optional[str], + third_party_invite: Optional[Dict], + room_id: str, + new_room: bool, + published_room: bool, ) -> bool: """Checks if a given user may send an invite If this method returns false, the invite will be rejected. Args: - inviter_userid: The user ID of the sender of the invitation - invitee_userid: The user ID targeted in the invitation - room_id: The room ID + inviter_userid: + invitee_userid: The user ID of the invitee. Is None + if this is a third party invite and the 3PID is not bound to a + user ID. + third_party_invite: If a third party invite then is a + dict containing the medium and address of the invitee. + room_id: + new_room: Whether the user is being invited to the room as + part of a room creation, if so the invitee would have been + included in the call to `user_may_create_room`. + published_room: Whether the room the user is being invited + to has been published in the local homeserver's public room + directory. Returns: True if the user may send an invite, otherwise False @@ -81,7 +97,12 @@ class SpamChecker: if ( await maybe_awaitable( spam_checker.user_may_invite( - inviter_userid, invitee_userid, room_id + inviter_userid, + invitee_userid, + third_party_invite, + room_id, + new_room, + published_room, ) ) is False @@ -90,20 +111,36 @@ class SpamChecker: return True - async def user_may_create_room(self, userid: str) -> bool: + async def user_may_create_room( + self, + userid: str, + invite_list: List[str], + third_party_invite_list: List[Dict], + cloning: bool, + ) -> bool: """Checks if a given user may create a room If this method returns false, the creation request will be rejected. Args: userid: The ID of the user attempting to create a room + invite_list: List of user IDs that would be invited to + the new room. + third_party_invite_list: List of third party invites + for the new room. + cloning: Whether the user is cloning an existing room, e.g. + upgrading a room. Returns: True if the user may create a room, otherwise False """ for spam_checker in self.spam_checkers: if ( - await maybe_awaitable(spam_checker.user_may_create_room(userid)) + await maybe_awaitable( + spam_checker.user_may_create_room( + userid, invite_list, third_party_invite_list, cloning + ) + ) is False ): return False @@ -156,6 +193,25 @@ class SpamChecker: return True + def user_may_join_room(self, userid: str, room_id: str, is_invited: bool): + """Checks if a given users is allowed to join a room. + + Not called when a user creates a room. + + Args: + userid: + room_id: + is_invited: Whether the user is invited into the room + + Returns: + bool: Whether the user may join the room + """ + for spam_checker in self.spam_checkers: + if spam_checker.user_may_join_room(userid, room_id, is_invited) is False: + return False + + return True + async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9c22e33813..4eadda4b40 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -240,6 +240,7 @@ def format_event_for_client_v1(d): "replaces_state", "prev_content", "invite_room_state", + "knock_room_state", ) for key in copy_keys: if key in d["unsigned"]: @@ -276,7 +277,7 @@ def serialize_event( event_format=format_event_for_client_v1, token_id=None, only_event_fields=None, - is_invite=False, + include_stripped_room_state=False, ): """Serialize event for clients @@ -287,8 +288,10 @@ def serialize_event( event_format token_id only_event_fields - is_invite (bool): Whether this is an invite that is being sent to the - invitee + include_stripped_room_state (bool): Some events can have stripped room state + stored in the `unsigned` field. This is required for invite and knock + functionality. If this option is False, that state will be removed from the + event before it is returned. Otherwise, it will be kept. Returns: dict @@ -320,11 +323,13 @@ def serialize_event( if txn_id is not None: d["unsigned"]["transaction_id"] = txn_id - # If this is an invite for somebody else, then we don't care about the - # invite_room_state as that's meant solely for the invitee. Other clients - # will already have the state since they're in the room. - if not is_invite: + # invite_room_state and knock_room_state are a list of stripped room state events + # that are meant to provide metadata about a room to an invitee/knocker. They are + # intended to only be included in specific circumstances, such as down sync, and + # should not be included in any other case. + if not include_stripped_room_state: d["unsigned"].pop("invite_room_state", None) + d["unsigned"].pop("knock_room_state", None) if as_client_event: d = event_format(d) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d330ae5dbc..22bb1afd68 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyrignt 2020 Sorunome +# Copyrignt 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -553,7 +555,7 @@ class FederationClient(FederationBase): RuntimeError: if no servers were reachable. """ - valid_memberships = {Membership.JOIN, Membership.LEAVE} + valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" @@ -810,7 +812,7 @@ class FederationClient(FederationBase): "User's homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - elif e.code == 403: + elif e.code in (403, 429): raise e.to_synapse_error() else: raise @@ -888,6 +890,62 @@ class FederationClient(FederationBase): # content. return resp[1] + async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict: + """Attempts to send a knock event to given a list of servers. Iterates + through the list until one attempt succeeds. + + Doing so will cause the remote server to add the event to the graph, + and send the event out to the rest of the federation. + + Args: + destinations: A list of candidate homeservers which are likely to be + participating in the room. + pdu: The event to be sent. + + Returns: + The remote homeserver return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + + Raises: + SynapseError: If the chosen remote server returns a 3xx/4xx code. + RuntimeError: If no servers were reachable. + """ + + async def send_request(destination: str) -> JsonDict: + return await self._do_send_knock(destination, pdu) + + return await self._try_destination_list( + "xyz.amorgan.knock/send_knock", destinations, send_request + ) + + async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict: + """Send a knock event to a remote homeserver. + + Args: + destination: The homeserver to send to. + pdu: The event to send. + + Returns: + The remote homeserver can optionally return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + """ + time_now = self._clock.time_msec() + + return await self.transport_layer.send_knock_v2( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + async def get_public_rooms( self, remote_server: str, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 171d25c945..c5e57b9d11 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.config.api import DEFAULT_ROOM_STATE_TYPES from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions @@ -567,6 +568,76 @@ class FederationServer(FederationBase): await self.handler.on_send_leave_request(origin, pdu) return {} + async def on_make_knock_request( + self, origin: str, room_id: str, user_id: str, supported_versions: List[str] + ) -> Dict[str, Union[EventBase, str]]: + """We've received a /make_knock/ request, so we create a partial knock + event for the room and hand that back, along with the room version, to the knocking + homeserver. We do *not* persist or process this event until the other server has + signed it and sent it back. + + Args: + origin: The (verified) server name of the requesting server. + room_id: The room to create the knock event in. + user_id: The user to create the knock for. + supported_versions: The room versions supported by the requesting server. + + Returns: + The partial knock event. + """ + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + room_version = await self.store.get_room_version_id(room_id) + if room_version not in supported_versions: + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) + raise IncompatibleRoomVersionError(room_version=room_version) + + pdu = await self.handler.on_make_knock_request(origin, room_id, user_id) + time_now = self._clock.time_msec() + return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + + async def on_send_knock_request( + self, origin: str, content: JsonDict, room_id: str, + ) -> Dict[str, List[JsonDict]]: + """ + We have received a knock event for a room. Verify and send the event into the room + on the knocking homeserver's behalf. Then reply with some stripped state from the + room for the knockee. + + Args: + origin: The remote homeserver of the knocking user. + content: The content of the request. + room_id: The ID of the room to knock on. + + Returns: + The stripped room state. + """ + logger.debug("on_send_knock_request: content: %s", content) + + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) + + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, pdu.room_id) + + logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + # Handle the event, and retrieve the EventContext + event_context = await self.handler.on_send_knock_request(origin, pdu) + + # Retrieve stripped state events from the room and send them back to the remote + # server. This will allow the remote server's clients to display information + # related to the room while the knock request is pending. + stripped_room_state = await self.store.get_stripped_room_state_from_event_context( + event_context, DEFAULT_ROOM_STATE_TYPES + ) + return {"knock_state_events": stripped_room_state} + async def on_event_auth( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender: self._wake_destinations_needing_catchup, ) + self._external_cache = hs.get_external_cache() + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -197,22 +199,40 @@ class FederationSender: if not event.internal_metadata.should_proactively_send(): return - try: - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - destinations = await self.state.get_hosts_in_room_at_events( - event.room_id, event_ids=event.prev_event_ids() - ) - except Exception: - logger.exception( - "Failed to calculate hosts in room for event: %s", - event.event_id, + destinations = None # type: Optional[Set[str]] + if not event.prev_event_ids(): + # If there are no prev event IDs then the state is empty + # and so no remote servers in the room + destinations = set() + else: + # We check the external cache for the destinations, which is + # stored per state group. + + sg = await self._external_cache.get( + "event_to_prev_state_group", event.event_id ) - return + if sg: + destinations = await self._external_cache.get( + "get_joined_hosts", str(sg) + ) + + if destinations is None: + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = await self.state.get_hosts_in_room_at_events( + event.room_id, event_ids=event.prev_event_ids() + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return destinations = { d diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index abe9168c78..9c454e5885 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2020 Sorunome +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +18,7 @@ import logging import urllib -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError @@ -26,6 +28,7 @@ from synapse.api.urls import ( FEDERATION_V2_PREFIX, ) from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -209,13 +212,24 @@ class TransportLayerClient: Fails with ``FederationDeniedError`` if the remote destination is not in our federation whitelist """ - valid_memberships = {Membership.JOIN, Membership.LEAVE} + valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) + + # Knock currently uses an unstable prefix + if membership == Membership.KNOCK: + # Create a path in the form of /unstable/xyz.amorgan.knock/make_knock/... + path = _create_path( + FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock", + "/make_knock/%s/%s", + room_id, + user_id, + ) + else: + path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) ignore_backoff = False retry_on_dns_fail = False @@ -294,6 +308,41 @@ class TransportLayerClient: return response @log_function + async def send_knock_v2( + self, destination: str, room_id: str, event_id: str, content: JsonDict, + ) -> JsonDict: + """ + Sends a signed knock membership event to a remote server. This is the second + step for knocking after make_knock. + + Args: + destination: The remote homeserver. + room_id: The ID of the room to knock on. + event_id: The ID of the knock membership event that we're sending. + content: The knock membership event that we're sending. Note that this is not the + `content` field of the membership event, but the entire signed membership event + itself represented as a JSON dict. + + Returns: + The remote homeserver can optionally return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock", + "/send_knock/%s/%s", + room_id, + event_id, + ) + + return await self.client.put_json( + destination=destination, path=path, data=content + ) + + @log_function async def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) @@ -1004,6 +1053,20 @@ class TransportLayerClient: return self.client.get_json(destination=destination, path=path) + def get_info_of_users(self, destination: str, user_ids: List[str]): + """ + Args: + destination: The remote server + user_ids: A list of user IDs to query info about + + Returns: + Deferred[List]: A dictionary of User ID to information about that user. + """ + path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info") + data = {"user_ids": user_ids} + + return self.client.post_json(destination=destination, path=path, data=data) + def _create_path(federation_prefix, path, *args): """ diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 95c64510a9..e9fb8d4079 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +15,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import functools import logging import re @@ -30,9 +30,11 @@ from synapse.api.urls import ( ) from synapse.http.server import JsonResource from synapse.http.servlet import ( + assert_params_in_dict, parse_boolean_from_args, parse_integer_from_args, parse_json_object_from_request, + parse_list_from_args, parse_string_from_args, ) from synapse.logging.context import run_in_background @@ -542,6 +544,34 @@ class FederationV2SendLeaveServlet(BaseFederationServlet): return 200, content +class FederationMakeKnockServlet(BaseFederationServlet): + PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" + + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock" + + async def on_GET(self, origin, content, query, room_id, user_id): + try: + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_list_from_args(query, "ver", encoding="utf-8") + except KeyError: + raise SynapseError(400, "Missing required query parameter 'ver'") + + content = await self.handler.on_make_knock_request( + origin, room_id, user_id, supported_versions=supported_versions + ) + return 200, content + + +class FederationV2SendKnockServlet(BaseFederationServlet): + PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock" + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, content + + class FederationEventAuthServlet(BaseFederationServlet): PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" @@ -842,6 +872,57 @@ class PublicRoomList(BaseFederationServlet): return 200, data +class FederationUserInfoServlet(BaseFederationServlet): + """ + Return information about a set of users. + + This API returns expiration and deactivation information about a set of + users. Requested users not local to this homeserver will be ignored. + + Example request: + POST /users/info + + { + "user_ids": [ + "@alice:example.com", + "@bob:example.com" + ] + } + + Example response + { + "@alice:example.com": { + "expired": false, + "deactivated": true + } + } + """ + + PATH = "/users/info" + PREFIX = FEDERATION_UNSTABLE_PREFIX + + def __init__(self, handler, authenticator, ratelimiter, server_name): + super(FederationUserInfoServlet, self).__init__( + handler, authenticator, ratelimiter, server_name + ) + self.handler = handler + + async def on_POST(self, origin, content, query): + assert_params_in_dict(content, required=["user_ids"]) + + user_ids = content.get("user_ids", []) + + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + data = await self.handler.store.get_info_for_users(user_ids) + return 200, data + + class FederationVersionServlet(BaseFederationServlet): PATH = "/version" @@ -1387,11 +1468,13 @@ FEDERATION_SERVLET_CLASSES = ( FederationQueryServlet, FederationMakeJoinServlet, FederationMakeLeaveServlet, + FederationMakeKnockServlet, FederationEventServlet, FederationV1SendJoinServlet, FederationV2SendJoinServlet, FederationV1SendLeaveServlet, FederationV2SendLeaveServlet, + FederationV2SendKnockServlet, FederationV1InviteServlet, FederationV2InviteServlet, FederationGetMissingEventsServlet, @@ -1403,6 +1486,7 @@ FEDERATION_SERVLET_CLASSES = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, + FederationUserInfoServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] OPENID_SERVLET_CLASSES = ( diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 664d09da1c..ce97fa70d7 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -18,11 +18,14 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import UserID from synapse.util import stringutils @@ -40,27 +43,37 @@ class AccountValidityHandler: self.sendmail = self.hs.get_sendmail() self.clock = self.hs.get_clock() - self._account_validity = self.hs.config.account_validity + self._account_validity_enabled = self.hs.config.account_validity_enabled + self._account_validity_renew_by_email_enabled = ( + self.hs.config.account_validity_renew_by_email_enabled + ) + self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory + self.profile_handler = self.hs.get_profile_handler() + + self._account_validity_period = None + if self._account_validity_enabled: + self._account_validity_period = self.hs.config.account_validity_period if ( - self._account_validity.enabled - and self._account_validity.renew_by_email_enabled + self._account_validity_enabled + and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. self._template_html = self.config.account_validity_template_html self._template_text = self.config.account_validity_template_text + account_validity_renew_email_subject = ( + self.hs.config.account_validity_renew_email_subject + ) try: app_name = self.hs.config.email_app_name - self._subject = self._account_validity.renew_email_subject % { - "app": app_name - } + self._subject = account_validity_renew_email_subject % {"app": app_name} self._from_string = self.hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. - self._subject = self._account_validity.renew_email_subject + self._subject = account_validity_renew_email_subject self._from_string = self.hs.config.email_notif_from self._raw_from = email.utils.parseaddr(self._from_string)[1] @@ -69,6 +82,18 @@ class AccountValidityHandler: if hs.config.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) + # Mark users as inactive when they expired. Check once every hour + if self._account_validity_enabled: + + def mark_expired_users_as_inactive(): + # run as a background process to allow async functions to work + return run_as_background_process( + "_mark_expired_users_as_inactive", + self._mark_expired_users_as_inactive, + ) + + self.clock.looping_call(mark_expired_users_as_inactive, 60 * 60 * 1000) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time @@ -221,47 +246,107 @@ class AccountValidityHandler: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - async def renew_account(self, renewal_token: str) -> bool: + async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. + If it turns out that the token is valid but has already been used, then the + token is considered stale. A token is stale if the 'token_used_ts_ms' db column + is non-null. + Args: renewal_token: Token sent with the renewal request. Returns: - Whether the provided token is valid. + A tuple containing: + * A bool representing whether the token is valid and unused. + * A bool representing whether the token is stale. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. """ try: - user_id = await self.store.get_user_from_renewal_token(renewal_token) + ( + user_id, + current_expiration_ts, + token_used_ts, + ) = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - return False + return False, False, 0 + + # Check whether this token has already been used. + if token_used_ts: + logger.info( + "User '%s' attempted to use previously used token '%s' to renew account", + user_id, + renewal_token, + ) + return False, True, current_expiration_ts logger.debug("Renewing an account for user %s", user_id) - await self.renew_account_for_user(user_id) - return True + # Renew the account. Pass the renewal_token here so that it is not cleared. + # We want to keep the token around in case the user attempts to renew their + # account with the same token twice (clicking the email link twice). + # + # In that case, the token will be accepted, but the account's expiration ts + # will remain unchanged. + new_expiration_ts = await self.renew_account_for_user( + user_id, renewal_token=renewal_token + ) + + return True, False, new_expiration_ts async def renew_account_for_user( - self, user_id: str, expiration_ts: int = None, email_sent: bool = False + self, + user_id: str, + expiration_ts: Optional[int] = None, + email_sent: bool = False, + renewal_token: Optional[str] = None, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token: Token sent with the renewal request. + user_id: The ID of the user to renew. expiration_ts: New expiration date. Defaults to now + validity period. - email_sen: Whether an email has been sent for this validity period. - Defaults to False. + email_sent: Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. The user's token + will be cleared if this is None. Returns: New expiration date for this account, as a timestamp in milliseconds since epoch. """ + now = self.clock.time_msec() if expiration_ts is None: - expiration_ts = self.clock.time_msec() + self._account_validity.period + expiration_ts = now + self._account_validity_period await self.store.set_account_validity_for_user( - user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent + user_id=user_id, + expiration_ts=expiration_ts, + email_sent=email_sent, + renewal_token=renewal_token, + token_used_ts=now, ) + # Check if renewed users should be reintroduced to the user directory + if self._show_users_in_user_directory: + # Show the user in the directory again by setting them to active + await self.profile_handler.set_active( + [UserID.from_string(user_id)], True, True + ) + return expiration_ts + + async def _mark_expired_users_as_inactive(self): + """Iterate over active, expired users. Mark them as inactive in order to hide them + from the user directory. + + Returns: + Deferred + """ + # Get active, expired users + active_expired_users = await self.store.get_expired_users() + + # Mark each as non-active + await self.profile_handler.set_active(active_expired_users, False, True) diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 8476256a59..5ecb2da1ac 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py
@@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING import twisted import twisted.internet.error @@ -22,6 +23,9 @@ from twisted.web.resource import Resource from synapse.app import check_bind_error +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) ACME_REGISTER_FAIL_ERROR = """ @@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC class AcmeHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.reactor = hs.get_reactor() self._acme_domain = hs.config.acme_domain - async def start_listening(self): + async def start_listening(self) -> None: from synapse.handlers import acme_issuing_service # Configure logging for txacme, if you need to debug @@ -85,7 +89,7 @@ class AcmeHandler: logger.error(ACME_REGISTER_FAIL_ERROR) raise - async def provision_certificate(self): + async def provision_certificate(self) -> None: logger.warning("Reprovisioning %s", self._acme_domain) @@ -110,5 +114,3 @@ class AcmeHandler: except Exception: logger.exception("Failed saving!") raise - - return True diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
index 7294649d71..ae2a9dd9c2 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to imported conditionally. """ import logging +from typing import Dict, Iterable, List import attr +import pem from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from josepy import JWKRSA @@ -36,20 +38,27 @@ from txacme.util import generate_private_key from zope.interface import implementer from twisted.internet import defer +from twisted.internet.interfaces import IReactorTCP from twisted.python.filepath import FilePath from twisted.python.url import URL +from twisted.web.resource import IResource logger = logging.getLogger(__name__) -def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource): +def create_issuing_service( + reactor: IReactorTCP, + acme_url: str, + account_key_file: str, + well_known_resource: IResource, +) -> AcmeIssuingService: """Create an ACME issuing service, and attach it to a web Resource Args: reactor: twisted reactor - acme_url (str): URL to use to request certificates - account_key_file (str): where to store the account key - well_known_resource (twisted.web.IResource): web resource for .well-known. + acme_url: URL to use to request certificates + account_key_file: where to store the account key + well_known_resource: web resource for .well-known. we will attach a child resource for "acme-challenge". Returns: @@ -83,18 +92,20 @@ class ErsatzStore: A store that only stores in memory. """ - certs = attr.ib(default=attr.Factory(dict)) + certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict)) - def store(self, server_name, pem_objects): + def store( + self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject] + ) -> defer.Deferred: self.certs[server_name] = [o.as_bytes() for o in pem_objects] return defer.succeed(None) -def load_or_create_client_key(key_file): +def load_or_create_client_key(key_file: str) -> JWKRSA: """Load the ACME account key from a file, creating it if it does not exist. Args: - key_file (str): name of the file to use as the account key + key_file: name of the file to use as the account key """ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't # hardcode the 'client.key' filename diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6f746711ca..a19c556437 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -568,16 +568,6 @@ class AuthHandler(BaseHandler): session.session_id, login_type, result ) except LoginError as e: - if login_type == LoginType.EMAIL_IDENTITY: - # riot used to have a bug where it would request a new - # validation token (thus sending a new email) each time it - # got a 401 with a 'flows' field. - # (https://github.com/vector-im/vector-web/issues/2447). - # - # Grandfather in the old behaviour for now to avoid - # breaking old riot deployments. - raise - # this step failed. Merge the error dict into the response # so that the client can have another go. errordict = e.error_dict() diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048523ec94..bd35d1fb87 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
@@ -100,11 +100,7 @@ class CasHandler: Returns: The URL to use as a "service" parameter. """ - return "%s%s?%s" % ( - self._cas_service_url, - "/_matrix/client/r0/login/cas/ticket", - urllib.parse.urlencode(args), - ) + return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index c4a3b26a84..ac25e3e94f 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -50,7 +50,7 @@ class DeactivateAccountHandler(BaseHandler): if hs.config.run_background_tasks: hs.get_reactor().callWhenRunning(self._start_user_parting) - self._account_validity_enabled = hs.config.account_validity.enabled + self._account_validity_enabled = hs.config.account_validity_enabled async def deactivate_account( self, @@ -120,6 +120,9 @@ class DeactivateAccountHandler(BaseHandler): await self.store.user_set_password_hash(user_id, None) + user = UserID.from_string(user_id) + await self._profile_handler.set_active([user], False, False) + # Add the user to a table of users pending deactivation (ie. # removal from all the rooms they're a member of) await self.store.add_user_pending_deactivation(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index debb1b4f29..0863154f7a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: + async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ Retrieve the given user's devices @@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device(self, user_id: str, device_id: str) -> JsonDict: """ Retrieve the given device Args: @@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] + device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -946,8 +946,8 @@ class DeviceListUpdater: async def process_cross_signing_key_update( self, user_id: str, - master_key: Optional[Dict[str, Any]], - self_signing_key: Optional[Dict[str, Any]], + master_key: Optional[JsonDict], + self_signing_key: Optional[JsonDict], ) -> List[str]: """Process the given new master and self-signing key for the given remote user. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 929752150d..8f3a6b35a4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -16,7 +16,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( + JsonDict, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class E2eKeysHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -78,7 +82,9 @@ class E2eKeysHandler: ) @trace - async def query_devices(self, query_body, timeout, from_user_id): + async def query_devices( + self, query_body: JsonDict, timeout: int, from_user_id: str + ) -> JsonDict: """ Handle a device key query from a client { @@ -98,12 +104,14 @@ class E2eKeysHandler: } Args: - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Iterable[str]] # separate users by domain. # make a map from domain to user_id to device_ids @@ -121,7 +129,8 @@ class E2eKeysHandler: set_tag("remote_key_query", remote_queries) # First get local devices. - failures = {} + # A map of destination -> failure response. + failures = {} # type: Dict[str, JsonDict] results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -135,9 +144,10 @@ class E2eKeysHandler: ) # Now attempt to get any remote devices from our local cache. - remote_queries_not_in_cache = {} + # A map of destination -> user ID -> device IDs. + remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] if remote_queries: - query_list = [] + query_list = [] # type: List[Tuple[str, Optional[str]]] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) @@ -284,15 +294,15 @@ class E2eKeysHandler: return ret async def get_cross_signing_keys_from_cache( - self, query, from_user_id + self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: """Get cross-signing keys for users from the database Args: - query (Iterable[string]) an iterable of user IDs. A dict whose keys + query: an iterable of user IDs. A dict whose keys are user IDs satisfies this, so the query format used for query_devices can be used here. - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. @@ -315,14 +325,12 @@ class E2eKeysHandler: if "self_signing" in user_info: self_signing_keys[user_id] = user_info["self_signing"] - if ( - from_user_id in keys - and keys[from_user_id] is not None - and "user_signing" in keys[from_user_id] - ): - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id: + from_user_key = keys.get(from_user_id) + if from_user_key and "user_signing" in from_user_key: + user_signing_keys[from_user_id] = from_user_key["user_signing"] return { "master_keys": master_keys, @@ -344,9 +352,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] + local_query = [] # type: List[Tuple[str, Optional[str]]] - result_dict = {} + result_dict = {} # type: Dict[str, Dict[str, dict]] for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -380,10 +388,14 @@ class E2eKeysHandler: log_kv(results) return result_dict - async def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys( + self, query_body: Dict[str, Dict[str, Optional[List[str]]]] + ) -> JsonDict: """ Handle a device key query from a federated server """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Optional[List[str]]] res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -397,31 +409,34 @@ class E2eKeysHandler: return ret @trace - async def claim_one_time_keys(self, query, timeout): - local_query = [] - remote_queries = {} + async def claim_one_time_keys( + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + ) -> JsonDict: + local_query = [] # type: List[Tuple[str, str, str]] + remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] - for user_id, device_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in device_keys.items(): + for device_id, algorithm in one_time_keys.items(): local_query.append((user_id, device_id, algorithm)) else: domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_keys + remote_queries.setdefault(domain, {})[user_id] = one_time_keys set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) results = await self.store.claim_e2e_one_time_keys(local_query) - json_result = {} - failures = {} + # A map of user ID -> device ID -> key ID -> key. + json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + failures = {} # type: Dict[str, JsonDict] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_bytes) + key_id: json_decoder.decode(json_str) } @trace @@ -468,7 +483,9 @@ class E2eKeysHandler: return {"one_time_keys": json_result, "failures": failures} @tag_args - async def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user( + self, user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: time_now = self.clock.time_msec() @@ -543,8 +560,8 @@ class E2eKeysHandler: return {"one_time_key_counts": result} async def _upload_one_time_keys_for_user( - self, user_id, device_id, time_now, one_time_keys - ): + self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict + ) -> None: logger.info( "Adding one_time_keys %r for device %r for user %r at %d", one_time_keys.keys(), @@ -585,12 +602,14 @@ class E2eKeysHandler: log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - async def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user( + self, user_id: str, keys: JsonDict + ) -> JsonDict: """Upload signing keys for cross-signing Args: - user_id (string): the user uploading the keys - keys (dict[string, dict]): the signing keys + user_id: the user uploading the keys + keys: the signing keys """ # if a master key is uploaded, then check it. Otherwise, load the @@ -667,16 +686,17 @@ class E2eKeysHandler: return {} - async def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys( + self, user_id: str, signatures: JsonDict + ) -> JsonDict: """Upload device signatures for cross-signing Args: - user_id (string): the user uploading the signatures - signatures (dict[string, dict[string, dict]]): map of users to - devices to signed keys. This is the submission from the user; an - exception will be raised if it is malformed. + user_id: the user uploading the signatures + signatures: map of users to devices to signed keys. This is the submission + from the user; an exception will be raised if it is malformed. Returns: - dict: response to be sent back to the client. The response will have + The response to be sent back to the client. The response will have a "failures" key, which will be a dict mapping users to devices to errors for the signatures that failed. Raises: @@ -719,7 +739,9 @@ class E2eKeysHandler: return {"failures": failures} - async def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures( + self, user_id: str, signatures: JsonDict + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -731,15 +753,14 @@ class E2eKeysHandler: signatures (dict[string, dict]): map of devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure - reasons + A tuple of a list of signatures to store, and a map of users to + devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -834,19 +855,24 @@ class E2eKeysHandler: return signature_list, failures def _check_master_key_signature( - self, user_id, master_key_id, signed_master_key, stored_master_key, devices - ): + self, + user_id: str, + master_key_id: str, + signed_master_key: JsonDict, + stored_master_key: JsonDict, + devices: Dict[str, Dict[str, JsonDict]], + ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. Args: - user_id (string): the user whose master key is being checked - master_key_id (string): the ID of the user's master key - signed_master_key (dict): the user's signed master key that was uploaded - stored_master_key (dict): our previously-stored copy of the user's master key - devices (iterable(dict)): the user's devices + user_id: the user whose master key is being checked + master_key_id: the ID of the user's master key + signed_master_key: the user's signed master key that was uploaded + stored_master_key: our previously-stored copy of the user's master key + devices: the user's devices Returns: - list[SignatureListItem]: a list of signatures to store + A list of signatures to store Raises: SynapseError: if a signature is invalid @@ -877,25 +903,26 @@ class E2eKeysHandler: return master_key_signature_list - async def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures( + self, user_id: str, signatures: Dict[str, dict] + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. Args: - user_id (string): the user uploading the keys - signatures (dict[string, dict]): map of users to devices to signed keys + user_id: the user uploading the keys + signatures: map of users to devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure + A list of signatures to store, and a map of users to devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -983,7 +1010,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None - ): + ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -997,8 +1024,7 @@ class E2eKeysHandler: This affects what signatures are fetched. Returns: - dict, str, VerifyKey: the raw key data, the key ID, and the - signedjson verify key + The raw key data, the key ID, and the signedjson verify key Raises: NotFoundError: if the key is not found @@ -1135,16 +1161,18 @@ class E2eKeysHandler: return desired_key, desired_key_id, desired_verify_key -def _check_cross_signing_key(key, user_id, key_type, signing_key=None): +def _check_cross_signing_key( + key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None +) -> None: """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. Args: - key (dict): the key data to verify - user_id (str): the user whose key is being checked - key_type (str): the type of key that the key should be - signing_key (VerifyKey): (optional) the signing key that the key should - be signed with. If omitted, signatures will not be checked. + key: the key data to verify + user_id: the user whose key is being checked + key_type: the type of key that the key should be + signing_key: the signing key that the key should be signed with. If + omitted, signatures will not be checked. """ if ( key.get("user_id") != user_id @@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) -def _check_device_signature(user_id, verify_key, signed_device, stored_device): +def _check_device_signature( + user_id: str, + verify_key: VerifyKey, + signed_device: JsonDict, + stored_device: JsonDict, +) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an exception if an error is detected. Args: - user_id (str): the user ID whose signature is being checked - verify_key (VerifyKey): the key to verify the device with - signed_device (dict): the uploaded signed device data - stored_device (dict): our previously stored copy of the device + user_id: the user ID whose signature is being checked + verify_key: the key to verify the device with + signed_device: the uploaded signed device data + stored_device: our previously stored copy of the device Raises: SynapseError: if the signature was invalid or the sent device is not the @@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) -def _exception_to_failure(e): +def _exception_to_failure(e: Exception) -> JsonDict: if isinstance(e, SynapseError): return {"status": e.code, "errcode": e.errcode, "message": str(e)} @@ -1218,7 +1251,7 @@ def _exception_to_failure(e): return {"status": 503, "message": str(e)} -def _one_time_keys_match(old_key_json, new_key): +def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly @@ -1239,16 +1272,16 @@ class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys. """ - signing_key_id = attr.ib() - target_user_id = attr.ib() - target_device_id = attr.ib() - signature = attr.ib() + signing_key_id = attr.ib(type=str) + target_user_id = attr.ib(type=str) + target_device_id = attr.ib(type=str) + signature = attr.ib(type=JsonDict) class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs, e2e_keys_handler): + def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater: iterable=True, ) - async def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. Args: - origin (string): the server that sent the EDU - edu_content (dict): the contents of the EDU + origin: the server that sent the EDU + edu_content: the contents of the EDU """ user_id = edu_content.pop("user_id") @@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater: await self._handle_signing_key_updates(user_id) - async def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id: str) -> None: """Actually handle pending updates. Args: - user_id (string): the user whose updates we are processing + user_id: the user whose updates we are processing """ device_handler = self.e2e_keys_handler.device_handler @@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] + device_ids = [] # type: List[str] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f01b090772..622cae23be 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py
@@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import ( Codes, @@ -24,8 +25,12 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, trace +from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ class E2eRoomKeysHandler: The actual payload of the encrypted keys is completely opaque to the handler. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # Used to lock whenever a client is uploading key data. This prevents collisions @@ -48,21 +53,27 @@ class E2eRoomKeysHandler: self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - async def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> List[JsonDict]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. Args: - user_id(str): the user whose keys we're getting - version(str): the version ID of the backup we're getting keys from - room_id(string): room ID to get keys for, for None to get keys for all rooms - session_id(string): session ID to get keys for, for None to get keys for all + user_id: the user whose keys we're getting + version: the version ID of the backup we're getting keys from + room_id: room ID to get keys for, for None to get keys for all rooms + session_id: session ID to get keys for, for None to get keys for all sessions Raises: NotFoundError: if the backup version does not exist Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -86,17 +97,23 @@ class E2eRoomKeysHandler: return results @trace - async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> JsonDict: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. Args: - user_id(str): the user whose backup we're deleting - version(str): the version ID of the backup we're deleting - room_id(string): room ID to delete keys for, for None to delete keys for all + user_id: the user whose backup we're deleting + version: the version ID of the backup we're deleting + room_id: room ID to delete keys for, for None to delete keys for all rooms - session_id(string): session ID to delete keys for, for None to delete keys + session_id: session ID to delete keys for, for None to delete keys for all sessions Raises: NotFoundError: if the backup version does not exist @@ -128,15 +145,17 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @trace - async def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys( + self, user_id: str, version: str, room_keys: JsonDict + ) -> JsonDict: """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_keys(dict): a nested dict describing the room_keys we're setting: + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_keys: a nested dict describing the room_keys we're setting: { "rooms": { @@ -254,14 +273,16 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @staticmethod - def _should_replace_room_key(current_room_key, room_key): + def _should_replace_room_key( + current_room_key: Optional[JsonDict], room_key: JsonDict + ) -> bool: """ Determine whether to replace a given current_room_key (if any) with a newly uploaded room_key backup Args: - current_room_key (dict): Optional, the current room_key dict if any - room_key (dict): The new room_key dict which may or may not be fit to + current_room_key: Optional, the current room_key dict if any + room_key : The new room_key dict which may or may not be fit to replace the current_room_key Returns: @@ -286,14 +307,14 @@ class E2eRoomKeysHandler: return True @trace - async def create_version(self, user_id, version_info): + async def create_version(self, user_id: str, version_info: JsonDict) -> str: """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. Args: - user_id(str): the user whose backup version we're creating - version_info(dict): metadata about the new version being created + user_id: the user whose backup version we're creating + version_info: metadata about the new version being created { "algorithm": "m.megolm_backup.v1", @@ -301,7 +322,7 @@ class E2eRoomKeysHandler: } Returns: - A deferred of a string that gives the new version number. + The new version number. """ # TODO: Validate the JSON to make sure it has the right keys. @@ -313,17 +334,19 @@ class E2eRoomKeysHandler: ) return new_version - async def get_version_info(self, user_id, version=None): + async def get_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're querying - version(str): Optional; if None gives the most recent version + user_id: the user whose current backup version we're querying + version: Optional; if None gives the most recent version otherwise a historical one. Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of a info dict that gives the info about the new version. + A info dict that gives the info about the new version. { "version": "1234", @@ -346,7 +369,7 @@ class E2eRoomKeysHandler: return res @trace - async def delete_version(self, user_id, version=None): + async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: """Deletes a given version of the user's e2e_room_keys backup Args: @@ -366,17 +389,19 @@ class E2eRoomKeysHandler: raise @trace - async def update_version(self, user_id, version, version_info): + async def update_version( + self, user_id: str, version: str, version_info: JsonDict + ) -> JsonDict: """Update the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're updating - version(str): the backup version we're updating - version_info(dict): the new information about the backup + user_id: the user whose current backup version we're updating + version: the backup version we're updating + version_info: the new information about the backup Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of an empty dict. + An empty dict. """ if "version" not in version_info: version_info["version"] = version diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..fa56b31438 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
@@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -186,7 +187,7 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info("handling received PDU: %s", pdu) + logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) # We reprocess pdus when we have seen them only as outliers existing = await self.store.get_event( @@ -301,6 +302,14 @@ class FederationHandler(BaseHandler): room_id, event_id, ) + elif missing_prevs: + logger.info( + "[%s %s] Not recursively fetching %d missing prev_events: %s", + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), + ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -345,12 +354,6 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) - logger.info( - "Event %s is missing prev_events: calculating state for a " - "backwards extremity", - event_id, - ) - # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. event_map = {event_id: pdu} @@ -368,7 +371,10 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "Requesting state at missing prev_event %s", event_id, + "[%s %s] Requesting state at missing prev_event %s", + room_id, + event_id, + p, ) with nested_logging_context(p): @@ -402,9 +408,7 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, + list(state_map.values()), get_prev_content=False, ) event_map.update(evs) @@ -1439,6 +1443,73 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) + @log_function + async def do_knock( + self, target_hosts: List[str], room_id: str, knockee: str, content: JsonDict, + ) -> Tuple[str, int]: + """Sends the knock to the remote server. + + This first triggers a make_knock request that returns a partial + event that we can fill out and sign. This is then sent to the + remote server via send_knock. + + Knock events must be signed by the knockee's server before distributing. + + Args: + target_hosts: A list of hosts that we want to try knocking through. + room_id: The ID of the room to knock on. + knockee: The ID of the user who is knocking. + content: The content of the knock event. + + Returns: + A tuple of (event ID, stream ID). + + Raises: + SynapseError: If the chosen remote server returns a 3xx/4xx code. + RuntimeError: If no servers were reachable. + """ + logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee) + + # Inform the remote server of the room versions we support + supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys()) + + # Ask the remote server to create a valid knock event for us. Once received, + # we sign the event + params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]] + origin, event, event_format_version = await self._make_and_verify_event( + target_hosts, room_id, knockee, Membership.KNOCK, content, params=params + ) + + # Record the room ID and its version so that we have a record of the room + await self._maybe_store_room_on_outlier_membership( + room_id=event.room_id, room_version=event_format_version + ) + + # Initially try the host that we successfully called /make_knock on + try: + target_hosts.remove(origin) + target_hosts.insert(0, origin) + except ValueError: + pass + + # Send the signed event back to the room, and potentially receive some + # further information about the room in the form of partial state events + stripped_room_state = await self.federation_client.send_knock( + target_hosts, event + ) + + # Store any stripped room state events in the "unsigned" key of the event. + # This is a bit of a hack and is cribbing off of invites. Basically we + # store the room state here and retrieve it again when this event appears + # in the invitee's sync stream. It is stripped out for all other local users. + event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] + + context = await self.state_handler.compute_event_context(event) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) + return event.event_id, stream_id + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. @@ -1593,8 +1664,15 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") + is_published = await self.store.is_room_published(event.room_id) + if not await self.spam_checker.user_may_invite( - event.sender, event.state_key, event.room_id + event.sender, + event.state_key, + None, + room_id=event.room_id, + new_room=False, + published_room=is_published, ): raise SynapseError( 403, "This user is not permitted to send invites to this server/user" @@ -1617,6 +1695,10 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + member_handler.ratelimit_invite(event.room_id, event.state_key) + # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). @@ -1775,6 +1857,120 @@ class FederationHandler(BaseHandler): return None + @log_function + async def on_make_knock_request( + self, origin: str, room_id: str, user_id: str + ) -> EventBase: + """We've received a make_knock request, so we create a partial + knock event for the room and return that. We do *not* persist or + process it until the other server has signed it and sent it back. + + Args: + origin: The (verified) server name of the requesting server. + room_id: The room to create the knock event in. + user_id: The user to create the knock for. + + Returns: + The partial knock event. + """ + if get_domain_from_id(user_id) != origin: + logger.info( + "Get /xyz.amorgan.knock/make_knock request for user %r" + "from different origin %s, ignoring", + user_id, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + room_version = await self.store.get_room_version_id(room_id) + + builder = self.event_builder_factory.new( + room_version, + { + "type": EventTypes.Member, + "content": {"membership": Membership.KNOCK}, + "room_id": room_id, + "sender": user_id, + "state_key": user_id, + }, + ) + + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.warning("Creation of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + try: + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_knock_request` + await self.auth.check_from_context( + room_version, event, context, do_sig_check=False + ) + except AuthError as e: + logger.warning("Failed to create new knock %r because %s", event, e) + raise e + + return event + + @log_function + async def on_send_knock_request( + self, origin: str, event: EventBase + ) -> EventContext: + """ + We have received a knock event for a room. Verify that event and send it into the room + on the knocking homeserver's behalf. + + Args: + origin: The remote homeserver of the knocking user. + event: The knocking member event that has been signed by the remote homeserver. + + Returns: + The context of the event after inserting it into the room graph. + """ + logger.debug( + "on_send_knock_request: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /xyz.amorgan.knock/send_knock request for user %r " + "from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + event.internal_metadata.outlier = False + + context = await self._handle_new_event(origin, event) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + logger.debug( + "on_send_knock_request: After _handle_new_event: %s, sigs: %s", + event.event_id, + event.signatures, + ) + + return context + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: """Returns the state at the event. i.e. not including said event. """ @@ -2093,6 +2289,11 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _check_for_soft_fail( diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index df29edeb83..71f11ef94a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -15,9 +15,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Dict, Iterable, List, Set from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import GroupID, get_domain_from_id +from synapse.types import GroupID, JsonDict, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ def _create_rerouter(func_name): class GroupsLocalWorkerHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.room_list_handler = hs.get_room_list_handler() @@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler: get_group_role = _create_rerouter("get_group_role") get_group_roles = _create_rerouter("get_group_roles") - async def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the group summary for a group. If the group is remote we check that the users have valid attestations. @@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler: return res - async def get_users_in_group(self, group_id, requester_user_id): + async def get_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users in a group """ if self.is_mine_id(group_id): - res = await self.groups_server_handler.get_users_in_group( + return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id ) - return res group_server_name = get_domain_from_id(group_id) @@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler: return res - async def get_joined_groups(self, user_id): + async def get_joined_groups(self, user_id: str) -> JsonDict: group_ids = await self.store.get_joined_groups(user_id) return {"groups": group_ids} - async def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: if self.hs.is_mine_id(user_id): result = await self.store.get_publicised_groups_for_user(user_id) @@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler: # TODO: Verify attestations return {"groups": result} - async def bulk_get_publicised_groups(self, user_ids, proxy=True): - destinations = {} + async def bulk_get_publicised_groups( + self, user_ids: Iterable[str], proxy: bool = True + ) -> JsonDict: + destinations = {} # type: Dict[str, Set[str]] local_users = set() for user_id in user_ids: @@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler: raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] + failed_results = [] # type: List[str] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( @@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # Ensure attestations get renewed @@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - async def create_group(self, group_id, user_id, content): + async def create_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Create a group """ @@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = None remote_attestation = None else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - content["user_profile"] = await self.profile_handler.get_profile(user_id) - - try: - res = await self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) + raise SynapseError(400, "Unable to create remote groups") is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( @@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def join_group(self, group_id, user_id, content): + async def join_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Request to join a group """ if self.is_mine_id(group_id): @@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def accept_invite(self, group_id, user_id, content): + async def accept_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Accept an invite to a group """ if self.is_mine_id(group_id): @@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def invite(self, group_id, user_id, requester_user_id, config): + async def invite( + self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict + ) -> JsonDict: """Invite a user to a group """ content = {"requester_user_id": requester_user_id, "config": config} @@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def on_invite(self, group_id, user_id, content): + async def on_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """One of our users were invited to a group """ # TODO: Support auto join and rejection @@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} async def remove_user_from_group( - self, group_id, user_id, requester_user_id, content - ): + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Remove a user from a group """ if user_id == requester_user_id: @@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def user_removed_from_group(self, group_id, user_id, content): + async def user_removed_from_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> None: """One of our users was removed/kicked from a group """ # TODO: Check if user in group diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index f61844d688..8dbf9bef3f 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,14 +22,18 @@ import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple from synapse.api.errors import ( + AuthError, CodeMessageException, Codes, HttpResponseException, + ProxiedRequestError, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -39,8 +43,6 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -id_server_scheme = "https://" - class IdentityHandler(BaseHandler): def __init__(self, hs): @@ -55,17 +57,46 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_federation_http_client() self.hs = hs + self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls + self._enable_lookup = hs.config.enable_3pid_lookup + self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, request: SynapseRequest, medium: str, address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( - self, id_server: str, creds: Dict[str, str] + self, id_server_url: str, creds: Dict[str, str] ) -> Optional[JsonDict]: """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server Args: - id_server: The identity server to validate 3PIDs against. Must be a + id_server_url: The identity server to validate 3PIDs against. Must be a complete URL including the protocol (http(s)://) creds: Dictionary containing the following keys: * client_secret|clientSecret: A unique secret str provided by the client @@ -90,7 +121,14 @@ class IdentityHandler(BaseHandler): query_params = {"sid": session_id, "client_secret": client_secret} - url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) + + url = "%s%s" % ( + id_server_url, + "/_matrix/identity/api/v1/3pid/getValidated3pid", + ) try: data = await self.http_client.get_json(url, query_params) @@ -99,7 +137,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info( "%s returned %i for threepid validation for: %s", - id_server, + id_server_url, e.code, creds, ) @@ -113,7 +151,7 @@ class IdentityHandler(BaseHandler): if "medium" in data: return data - logger.info("%s reported non-validated threepid: %s", id_server, creds) + logger.info("%s reported non-validated threepid: %s", id_server_url, creds) return None async def bind_threepid( @@ -145,14 +183,19 @@ class IdentityHandler(BaseHandler): if id_access_token is None: use_v2 = False + # if we have a rewrite rule set for the identity server, + # apply it now, but only for sending the request (not + # storing in the database). + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + # Decide which API endpoint URLs to use headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} if use_v2: - bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) + bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,) headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore else: - bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) + bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,) try: # Use the blacklisting http client as this call is only to identity servers @@ -239,9 +282,6 @@ class IdentityHandler(BaseHandler): True on success, otherwise False if the identity server doesn't support unbinding """ - url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") - content = { "mxid": mxid, "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, @@ -250,6 +290,7 @@ class IdentityHandler(BaseHandler): # we abuse the federation http client to sign the request, but we have to send it # using the normal http client since we don't want the SRV lookup and want normal # 'browser-like' HTTPS. + url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") auth_headers = self.federation_http_client.build_auth_headers( destination=None, method=b"POST", @@ -259,6 +300,15 @@ class IdentityHandler(BaseHandler): ) headers = {b"Authorization": auth_headers} + # if we have a rewrite rule set for the identity server, + # apply it now. + # + # Note that destination_is has to be the real id_server, not + # the server we connect to. + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,) + try: # Use the blacklisting http client as this call is only to identity servers # provided by a client @@ -373,9 +423,28 @@ class IdentityHandler(BaseHandler): return session_id + def rewrite_id_server_url(self, url: str, add_https=False) -> str: + """Given an identity server URL, optionally add a protocol scheme + before rewriting it according to the rewrite_identity_server_urls + config option + + Adds https:// to the URL if specified, then tries to rewrite the + url. Returns either the rewritten URL or the URL with optional + protocol scheme additions. + """ + rewritten_url = url + if add_https: + rewritten_url = "https://" + rewritten_url + + rewritten_url = self.rewrite_identity_server_urls.get( + rewritten_url, rewritten_url + ) + logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url) + return rewritten_url + async def requestEmailToken( self, - id_server: str, + id_server_url: str, email: str, client_secret: str, send_attempt: int, @@ -386,7 +455,7 @@ class IdentityHandler(BaseHandler): validation. Args: - id_server: The identity server to proxy to + id_server_url: The identity server to proxy to email: The email to send the message to client_secret: The unique client_secret sends by the user send_attempt: Which attempt this is @@ -400,6 +469,11 @@ class IdentityHandler(BaseHandler): "client_secret": client_secret, "send_attempt": send_attempt, } + + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) + if next_link: params["next_link"] = next_link @@ -414,7 +488,8 @@ class IdentityHandler(BaseHandler): try: data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/email/requestToken", + "%s/_matrix/identity/api/v1/validate/email/requestToken" + % (id_server_url,), params, ) return data @@ -426,7 +501,7 @@ class IdentityHandler(BaseHandler): async def requestMsisdnToken( self, - id_server: str, + id_server_url: str, country: str, phone_number: str, client_secret: str, @@ -437,7 +512,7 @@ class IdentityHandler(BaseHandler): Request an external server send an SMS message on our behalf for the purposes of threepid validation. Args: - id_server: The identity server to proxy to + id_server_url: The identity server to proxy to country: The country code of the phone number phone_number: The number to send the message to client_secret: The unique client_secret sends by the user @@ -465,9 +540,13 @@ class IdentityHandler(BaseHandler): "details and update your config file." ) + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) try: data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", + "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" + % (id_server_url,), params, ) except HttpResponseException as e: @@ -559,6 +638,86 @@ class IdentityHandler(BaseHandler): logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) raise SynapseError(400, "Error contacting the identity server") + # TODO: The following two methods are used for proxying IS requests using + # the CS API. They should be consolidated with those in RoomMemberHandler + # https://github.com/matrix-org/synapse-dinsic/issues/25 + + async def proxy_lookup_3pid( + self, id_server: str, medium: str, address: str + ) -> JsonDict: + """Looks up a 3pid in the passed identity server. + + Args: + id_server: The server name (including port, if required) + of the identity server to use. + medium: The type of the third party identifier (e.g. "email"). + address: The third party identifier (e.g. "foo@example.com"). + + Returns: + The result of the lookup. See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" + ) + + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + try: + data = await self.http_client.get_json( + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), + {"medium": medium, "address": address}, + ) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %s: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + return data + + async def proxy_bulk_lookup_3pid( + self, id_server: str, threepids: List[List[str]] + ) -> JsonDict: + """Looks up given 3pids in the passed identity server. + + Args: + id_server: The server name (including port, if required) + of the identity server to use. + threepids: The third party identifiers to lookup, as + a list of 2-string sized lists ([medium, address]). + + Returns: + The result of the lookup. See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" + ) + + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + try: + data = await self.http_client.post_json_get_json( + "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,), + {"threepids": threepids}, + ) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %s: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + return data + async def lookup_3pid( self, id_server: str, @@ -579,10 +738,13 @@ class IdentityHandler(BaseHandler): Returns: the matrix ID of the 3pid, or None if it is not recognized. """ + # Rewrite id_server URL if necessary + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + if id_access_token is not None: try: results = await self._lookup_3pid_v2( - id_server, id_access_token, medium, address + id_server_url, id_access_token, medium, address ) return results @@ -600,16 +762,17 @@ class IdentityHandler(BaseHandler): logger.warning("Error when looking up hashing details: %s", e) return None - return await self._lookup_3pid_v1(id_server, medium, address) + return await self._lookup_3pid_v1(id_server, id_server_url, medium, address) async def _lookup_3pid_v1( - self, id_server: str, medium: str, address: str + self, id_server: str, id_server_url: str, medium: str, address: str ) -> Optional[str]: """Looks up a 3pid in the passed identity server using v1 lookup. Args: id_server: The server name (including port, if required) of the identity server to use. + id_server_url: The actual, reachable domain of the id server medium: The type of the third party identifier (e.g. "email"). address: The third party identifier (e.g. "foo@example.com"). @@ -617,8 +780,8 @@ class IdentityHandler(BaseHandler): the matrix ID of the 3pid, or None if it is not recognized. """ try: - data = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), + data = await self.http_client.get_json( + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), {"medium": medium, "address": address}, ) @@ -635,13 +798,12 @@ class IdentityHandler(BaseHandler): return None async def _lookup_3pid_v2( - self, id_server: str, id_access_token: str, medium: str, address: str + self, id_server_url: str, id_access_token: str, medium: str, address: str ) -> Optional[str]: """Looks up a 3pid in the passed identity server using v2 lookup. Args: - id_server: The server name (including port, if required) - of the identity server to use. + id_server_url: The protocol scheme and domain of the id server id_access_token: The access token to authenticate to the identity server with medium: The type of the third party identifier (e.g. "email"). address: The third party identifier (e.g. "foo@example.com"). @@ -651,8 +813,8 @@ class IdentityHandler(BaseHandler): """ # Check what hashing details are supported by this identity server try: - hash_details = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), + hash_details = await self.http_client.get_json( + "%s/_matrix/identity/v2/hash_details" % (id_server_url,), {"access_token": id_access_token}, ) except RequestTimedOutError: @@ -660,15 +822,14 @@ class IdentityHandler(BaseHandler): if not isinstance(hash_details, dict): logger.warning( - "Got non-dict object when checking hash details of %s%s: %s", - id_server_scheme, - id_server, + "Got non-dict object when checking hash details of %s: %s", + id_server_url, hash_details, ) raise SynapseError( 400, - "Non-dict object from %s%s during v2 hash_details request: %s" - % (id_server_scheme, id_server, hash_details), + "Non-dict object from %s during v2 hash_details request: %s" + % (id_server_url, hash_details), ) # Extract information from hash_details @@ -682,8 +843,8 @@ class IdentityHandler(BaseHandler): ): raise SynapseError( 400, - "Invalid hash details received from identity server %s%s: %s" - % (id_server_scheme, id_server, hash_details), + "Invalid hash details received from identity server %s: %s" + % (id_server_url, hash_details), ) # Check if any of the supported lookup algorithms are present @@ -705,7 +866,7 @@ class IdentityHandler(BaseHandler): else: logger.warning( "None of the provided lookup algorithms of %s are supported: %s", - id_server, + id_server_url, supported_lookup_algorithms, ) raise SynapseError( @@ -718,8 +879,8 @@ class IdentityHandler(BaseHandler): headers = {"Authorization": create_id_access_token_header(id_access_token)} try: - lookup_results = await self.blacklisting_http_client.post_json_get_json( - "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), + lookup_results = await self.http_client.post_json_get_json( + "%s/_matrix/identity/v2/lookup" % (id_server_url,), { "addresses": [lookup_value], "algorithm": lookup_algorithm, @@ -807,15 +968,17 @@ class IdentityHandler(BaseHandler): if self._web_client_location: invite_config["org.matrix.web_client_location"] = self._web_client_location + # Rewrite the identity server URL if necessary + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + # Add the identity service access token to the JSON body and use the v2 # Identity Service endpoints if id_access_token is present data = None - base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) + base_url = "%s/_matrix/identity" % (id_server_url,) if id_access_token: - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % ( + id_server_url, ) # Attempt a v2 lookup @@ -834,9 +997,8 @@ class IdentityHandler(BaseHandler): raise e if data is None: - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_url, ) url = base_url + "/api/v1/store-invite" @@ -848,10 +1010,7 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( - "Error trying to call /store-invite on %s%s: %s", - id_server_scheme, - id_server, - e, + "Error trying to call /store-invite on %s: %s", id_server_url, e, ) if data is None: @@ -864,10 +1023,9 @@ class IdentityHandler(BaseHandler): ) except HttpResponseException as e: logger.warning( - "Error calling /store-invite on %s%s with fallback " + "Error calling /store-invite on %s with fallback " "encoding: %s", - id_server_scheme, - id_server, + id_server_url, e, ) raise e @@ -888,6 +1046,42 @@ class IdentityHandler(BaseHandler): display_name = data["display_name"] return token, public_keys, fallback_public_key, display_name + async def bind_email_using_internal_sydent_api( + self, id_server_url: str, email: str, user_id: str, + ): + """Bind an email to a fully qualified user ID using the internal API of an + instance of Sydent. + + Args: + id_server_url: The URL of the Sydent instance + email: The email address to bind + user_id: The user ID to bind the email to + + Raises: + HTTPResponseException: On a non-2xx HTTP response. + """ + # Extract the domain name from the IS URL as we store IS domains instead of URLs + id_server = urllib.parse.urlparse(id_server_url).hostname + if not id_server: + # We were unable to determine the hostname, bail out + return + + # id_server_url is assumed to have no trailing slashes + url = id_server_url + "/_matrix/identity/internal/bind" + body = { + "address": email, + "medium": "email", + "mxid": user_id, + } + + # Bind the threepid + await self.http_client.post_json_get_json(url, body) + + # Remember where we bound the threepid + await self.store.add_user_bound_threepid( + user_id=user_id, medium="email", address=email, id_server=id_server, + ) + def create_id_access_token_header(id_access_token: str) -> List[str]: """Create an Authorization header for passing to SimpleHttpClient as the header value diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..f3694e0973 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyrignt 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,6 +41,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.config.api import DEFAULT_ROOM_STATE_TYPES from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext @@ -432,6 +434,8 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._external_cache = hs.get_external_cache() + async def create_event( self, requester: Requester, @@ -492,7 +496,7 @@ class EventCreationHandler: membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership in {Membership.JOIN, Membership.INVITE}: + if membership in {Membership.JOIN, Membership.INVITE, Membership.KNOCK}: # If event doesn't include a display name, add one. profile = self.profile_handler content = builder.content @@ -918,8 +922,8 @@ class EventCreationHandler: room_version = await self.store.get_room_version_id(event.room_id) if event.internal_metadata.is_out_of_band_membership(): - # the only sort of out-of-band-membership events we expect to see here - # are invite rejections we have generated ourselves. + # the only sort of out-of-band-membership events we expect to see here are + # invite rejections and rescinded knocks that we have generated ourselves. assert event.type == EventTypes.Member assert event.content["membership"] == Membership.LEAVE else: @@ -939,6 +943,8 @@ class EventCreationHandler: await self.action_generator.handle_push_actions_for_event(event, context) + await self.cache_joined_hosts_for_event(event) + try: # If we're a worker we need to hit out to the master. writer_instance = self._events_shard_config.get_instance(event.room_id) @@ -978,6 +984,44 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise + async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + """Precalculate the joined hosts at the event, when using Redis, so that + external federation senders don't have to recalculate it themselves. + """ + + if not self._external_cache.is_enabled(): + return + + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We always set the state group -> joined hosts cache, even if + # we already set it, so that the expiry time is reset. + + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() + ) + + if state_entry.state_group: + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) + async def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str ) -> None: @@ -1125,6 +1169,13 @@ class EventCreationHandler: # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) + if event.content["membership"] == Membership.KNOCK: + event.unsigned[ + "knock_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, DEFAULT_ROOM_STATE_TYPES, + ) + if event.type == EventTypes.Redaction: original_event = await self.store.get_event( event.redacts, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index c02b951031..4b102ff9a9 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +15,11 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional + +from signedjson.sign import sign_json + +from twisted.internet import reactor from synapse.api.errors import ( AuthError, @@ -24,7 +29,11 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import ( JsonDict, Requester, @@ -54,6 +63,8 @@ class ProfileHandler(BaseHandler): PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 + PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000 + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -64,11 +75,98 @@ class ProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() + + self.max_avatar_size = hs.config.max_avatar_size + self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes + self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to + if hs.config.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS ) + if len(self.hs.config.replicate_user_profiles_to) > 0: + reactor.callWhenRunning(self._do_assign_profile_replication_batches) + reactor.callWhenRunning(self._start_replicate_profiles) + # Add a looping call to replicate_profiles: this handles retries + # if the replication is unsuccessful when the user updated their + # profile. + self.clock.looping_call( + self._start_replicate_profiles, self.PROFILE_REPLICATE_INTERVAL + ) + + def _do_assign_profile_replication_batches(self): + return run_as_background_process( + "_assign_profile_replication_batches", + self._assign_profile_replication_batches, + ) + + def _start_replicate_profiles(self): + return run_as_background_process( + "_replicate_profiles", self._replicate_profiles + ) + + async def _assign_profile_replication_batches(self): + """If no profile replication has been done yet, allocate replication batch + numbers to each profile to start the replication process. + """ + logger.info("Assigning profile batch numbers...") + total = 0 + while True: + assigned = await self.store.assign_profile_batch() + total += assigned + if assigned == 0: + break + logger.info("Assigned %d profile batch numbers", total) + + async def _replicate_profiles(self): + """If any profile data has been updated and not pushed to the replication targets, + replicate it. + """ + host_batches = await self.store.get_replication_hosts() + latest_batch = await self.store.get_latest_profile_replication_batch_number() + if latest_batch is None: + latest_batch = -1 + for repl_host in self.hs.config.replicate_user_profiles_to: + if repl_host not in host_batches: + host_batches[repl_host] = -1 + try: + for i in range(host_batches[repl_host] + 1, latest_batch + 1): + await self._replicate_host_profile_batch(repl_host, i) + except Exception: + logger.exception( + "Exception while replicating to %s: aborting for now", repl_host + ) + + async def _replicate_host_profile_batch(self, host, batchnum): + logger.info("Replicating profile batch %d to %s", batchnum, host) + batch_rows = await self.store.get_profile_batch(batchnum) + batch = { + UserID(r["user_id"], self.hs.hostname).to_string(): ( + {"display_name": r["displayname"], "avatar_url": r["avatar_url"]} + if r["active"] + else None + ) + for r in batch_rows + } + + url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,) + body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname} + signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0]) + try: + await self.http_client.post_json_get_json(url, signed_body) + await self.store.update_replication_batch_for_host(host, batchnum) + logger.info( + "Successfully replicated profile batch %d to %s", batchnum, host + ) + except Exception: + # This will get retried when the looping call next comes around + logger.exception( + "Failed to replicate profile batch %d to %s", batchnum, host + ) + raise + async def get_profile(self, user_id: str) -> JsonDict: target_user = UserID.from_string(user_id) @@ -210,8 +308,16 @@ class ProfileHandler(BaseHandler): target_user, authenticated_entity=requester.authenticated_entity, ) + if len(self.hs.config.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + await self.store.set_profile_displayname( - target_user.localpart, displayname_to_set + target_user.localpart, displayname_to_set, new_batchnum ) if self.hs.config.user_directory_search_all_users: @@ -222,6 +328,46 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) + # start a profile replication push + run_in_background(self._replicate_profiles) + + async def set_active( + self, users: List[UserID], active: bool, hide: bool, + ): + """ + Sets the 'active' flag on a set of user profiles. If set to false, the + accounts are considered deactivated or hidden. + + If 'hide' is true, then we interpret active=False as a request to try to + hide the users rather than deactivating them. This means withholding the + profiles from replication (and mark it as inactive) rather than clearing + the profile from the HS DB. + + Note that unlike set_displayname and set_avatar_url, this does *not* + perform authorization checks! This is because the only place it's used + currently is in account deactivation where we've already done these + checks anyway. + + Args: + users: The users to modify + active: Whether to set the user to active or inactive + hide: Whether to hide the user (withold from replication). If + False and active is False, user will have their profile + erased + """ + if len(self.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + + await self.store.set_profiles_active(users, active, hide, new_batchnum) + + # start a profile replication push + run_in_background(self._replicate_profiles) + async def get_avatar_url(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: @@ -290,14 +436,56 @@ class ProfileHandler(BaseHandler): if new_avatar_url == "": avatar_url_to_set = None + # Enforce a max avatar size if one is defined + if avatar_url_to_set and ( + self.max_avatar_size or self.allowed_avatar_mimetypes + ): + media_id = self._validate_and_parse_media_id_from_avatar_url( + avatar_url_to_set + ) + + # Check that this media exists locally + media_info = await self.store.get_local_media(media_id) + if not media_info: + raise SynapseError( + 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND + ) + + # Ensure avatar does not exceed max allowed avatar size + media_size = media_info["media_length"] + if self.max_avatar_size and media_size > self.max_avatar_size: + raise SynapseError( + 400, + "Avatars must be less than %s bytes in size" + % (self.max_avatar_size,), + errcode=Codes.TOO_LARGE, + ) + + # Ensure the avatar's file type is allowed + if ( + self.allowed_avatar_mimetypes + and media_info["media_type"] not in self.allowed_avatar_mimetypes + ): + raise SynapseError( + 400, "Avatar file type '%s' not allowed" % media_info["media_type"] + ) + # Same like set_displayname if by_admin: requester = create_requester( target_user, authenticated_entity=requester.authenticated_entity ) + if len(self.hs.config.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + await self.store.set_profile_avatar_url( - target_user.localpart, avatar_url_to_set + target_user.localpart, avatar_url_to_set, new_batchnum ) if self.hs.config.user_directory_search_all_users: @@ -308,6 +496,23 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) + # start a profile replication push + run_in_background(self._replicate_profiles) + + def _validate_and_parse_media_id_from_avatar_url(self, mxc): + """Validate and parse a provided avatar url and return the local media id + + Args: + mxc (str): A mxc URL + + Returns: + str: The ID of the media + """ + avatar_pieces = mxc.split("/") + if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:": + raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc) + return avatar_pieces[-1] + async def on_profile_query(self, args: JsonDict) -> JsonDict: user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 49b085269b..ab4d5ccc1c 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -49,6 +49,7 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() @@ -57,6 +58,8 @@ class RegistrationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() + self._show_in_user_directory = self.hs.config.show_users_in_user_directory + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -77,6 +80,16 @@ class RegistrationHandler(BaseHandler): guest_access_token: Optional[str] = None, assigned_user_id: Optional[str] = None, ): + """ + + Args: + localpart (str|None): The user's localpart + guest_access_token (str|None): A guest's access token + assigned_user_id (str|None): An existing User ID for this user if pre-calculated + + Returns: + Deferred + """ if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, @@ -119,6 +132,8 @@ class RegistrationHandler(BaseHandler): raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) + + # Retrieve guest user information from provided access token user_data = await self.auth.get_user_by_access_token(guest_access_token) if ( not user_data.is_guest @@ -237,6 +252,12 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) + if default_display_name: + requester = create_requester(user) + await self.profile_handler.set_displayname( + user, requester, default_display_name, by_admin=True + ) + if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(localpart) await self.user_directory_handler.handle_local_profile_change( @@ -246,8 +267,6 @@ class RegistrationHandler(BaseHandler): else: # autogen a sequential user ID fail_count = 0 - # If a default display name is not given, generate one. - generate_display_name = default_display_name is None # This breaks on successful registration *or* errors after 10 failures. while True: # Fail after being unable to find a suitable ID a few times @@ -258,7 +277,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) - if generate_display_name: + if default_display_name is None: default_display_name = localpart try: await self.register_with_store( @@ -270,6 +289,11 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) + requester = create_requester(user) + await self.profile_handler.set_displayname( + user, requester, default_display_name, by_admin=True + ) + # Successfully registered break except SynapseError: @@ -301,7 +325,15 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - await self._register_email_threepid(user_id, threepid_dict, None) + await self.register_email_threepid(user_id, threepid_dict, None) + + # Prevent the new user from showing up in the user directory if the server + # mandates it. + if not self._show_in_user_directory: + await self.store.add_account_data_for_user( + user_id, "im.vector.hide_profile", {"hide_profile": True} + ) + await self.profile_handler.set_active([user], False, True) return user_id @@ -501,7 +533,10 @@ class RegistrationHandler(BaseHandler): """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart: str, as_token: str) -> str: + async def appservice_register( + self, user_localpart: str, as_token: str, password_hash: str, display_name: str + ): + # FIXME: this should be factored out and merged with normal register() user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -518,12 +553,26 @@ class RegistrationHandler(BaseHandler): self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service) + display_name = display_name or user.localpart + await self.register_with_store( user_id=user_id, - password_hash="", + password_hash=password_hash, appservice_id=service_id, - create_profile_with_displayname=user.localpart, + create_profile_with_displayname=display_name, + ) + + requester = create_requester(user) + await self.profile_handler.set_displayname( + user, requester, display_name, by_admin=True ) + + if self.hs.config.user_directory_search_all_users: + profile = await self.store.get_profileinfo(user_localpart) + await self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + return user_id def check_user_id_not_appservice_exclusive( @@ -552,6 +601,37 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) + async def shadow_register(self, localpart, display_name, auth_result, params): + """Invokes the current registration on another server, using + shared secret registration, passing in any auth_results from + other registration UI auth flows (e.g. validated 3pids) + Useful for setting up shadow/backup accounts on a parallel deployment. + """ + + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token), + { + # XXX: auth_result is an unspecified extension for shadow registration + "auth_result": auth_result, + # XXX: another unspecified extension for shadow registration to ensure + # that the displayname is correctly set by the masters erver + "display_name": display_name, + "username": localpart, + "password": params.get("password"), + "bind_msisdn": params.get("bind_msisdn"), + "device_id": params.get("device_id"), + "initial_device_display_name": params.get( + "initial_device_display_name" + ), + "inhibit_login": False, + "access_token": as_token, + }, + ) + def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -704,6 +784,7 @@ class RegistrationHandler(BaseHandler): if auth_result and LoginType.EMAIL_IDENTITY in auth_result: threepid = auth_result[LoginType.EMAIL_IDENTITY] + # Necessary due to auth checks prior to the threepid being # written to the db if is_threepid_reserved( @@ -711,7 +792,32 @@ class RegistrationHandler(BaseHandler): ): await self.store.upsert_monthly_active_user(user_id) - await self._register_email_threepid(user_id, threepid, access_token) + await self.register_email_threepid(user_id, threepid, access_token) + + if self.hs.config.bind_new_user_emails_to_sydent: + # Attempt to call Sydent's internal bind API on the given identity server + # to bind this threepid + id_server_url = self.hs.config.bind_new_user_emails_to_sydent + + logger.debug( + "Attempting the bind email of %s to identity server: %s using " + "internal Sydent bind API.", + user_id, + self.hs.config.bind_new_user_emails_to_sydent, + ) + + try: + await self.identity_handler.bind_email_using_internal_sydent_api( + id_server_url, threepid["address"], user_id + ) + except Exception as e: + logger.warning( + "Failed to bind email of '%s' to Sydent instance '%s' ", + "using Sydent internal bind API: %s", + user_id, + id_server_url, + e, + ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] @@ -731,7 +837,7 @@ class RegistrationHandler(BaseHandler): await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - async def _register_email_threepid( + async def register_email_threepid( self, user_id: str, threepid: dict, token: Optional[str] ) -> None: """Add an email address as a 3pid identifier diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ee27d99135..d037742081 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._invite_burst_count = ( + hs.config.ratelimiting.rc_invites_per_room.burst_count + ) + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -359,7 +363,19 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if not await self.spam_checker.user_may_create_room(user_id): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to create rooms + is_requester_admin = True + + else: + is_requester_admin = await self.auth.is_server_admin(requester.user) + + if not is_requester_admin and not await self.spam_checker.user_may_create_room( + user_id, invite_list=[], third_party_invite_list=[], cloning=True + ): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { @@ -610,8 +626,14 @@ class RoomCreationHandler(BaseHandler): 403, "You are not permitted to create rooms", Codes.FORBIDDEN ) + invite_list = config.get("invite", []) + invite_3pid_list = config.get("invite_3pid", []) + if not is_requester_admin and not await self.spam_checker.user_may_create_room( - user_id + user_id, + invite_list=invite_list, + third_party_invite_list=invite_3pid_list, + cloning=False, ): raise SynapseError(403, "You are not permitted to create rooms") @@ -662,6 +684,9 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] + if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: + raise SynapseError(400, "Cannot invite so many users at once") + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") @@ -796,6 +821,7 @@ class RoomCreationHandler(BaseHandler): "invite", ratelimit=False, content=content, + new_room=True, ) for invite_3pid in invite_3pid_list: @@ -813,6 +839,7 @@ class RoomCreationHandler(BaseHandler): id_server, requester, txn_id=None, + new_room=True, id_access_token=id_access_token, ) @@ -890,6 +917,7 @@ class RoomCreationHandler(BaseHandler): "join", ratelimit=ratelimit, content=creator_join_profile, + new_room=True, ) # We treat the power levels override specially as this needs to be one diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 14f14db449..373b9dcd0d 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
@@ -170,6 +170,7 @@ class RoomListHandler(BaseHandler): "world_readable": room["history_visibility"] == HistoryVisibility.WORLD_READABLE, "guest_can_join": room["guest_access"] == "can_join", + "join_rule": room["join_rules"], } # Filter out Nones – rather omit the field altogether diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e001e418f9..a92f7ba012 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import abc import logging import random @@ -31,7 +31,15 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomAlias, + RoomID, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room @@ -85,6 +93,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + self._invites_per_room_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + ) + self._invites_per_user_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -111,6 +130,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + async def remote_knock( + self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict, + ) -> Tuple[str, int]: + """Try and knock on a room that this server is not in + + Args: + remote_room_hosts: List of servers that can be used to knock via. + room_id: Room that we are trying to knock on. + user: User who is trying to knock. + content: A dict that should be used as the content of the knock event. + """ + raise NotImplementedError() + + @abc.abstractmethod async def remote_reject_invite( self, invite_event_id: str, @@ -134,6 +167,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """Rescind a local knock made on a remote room. + + Args: + knock_event_id: The ID of the knock event to rescind. + txn_id: An optional transaction ID supplied by the client. + requester: The user making the request, according to the access token. + content: The content of the generated leave event. + + Returns: + A tuple containing (event_id, stream_id of the leave event). + """ + raise NotImplementedError() + + @abc.abstractmethod async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. @@ -144,6 +198,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + def ratelimit_invite(self, room_id: str, invitee_user_id: str): + """Ratelimit invites by room and by target user. + """ + self._invites_per_room_limiter.ratelimit(room_id) + self._invites_per_user_limiter.ratelimit(invitee_user_id) + async def _local_membership_update( self, requester: Requester, @@ -279,6 +339,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -319,6 +380,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed=third_party_signed, ratelimit=ratelimit, content=content, + new_room=new_room, require_consent=require_consent, ) @@ -335,6 +397,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, ) -> Tuple[str, int]: """Helper for update_membership. @@ -387,8 +450,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") if effective_membership_state == Membership.INVITE: + target_id = target.to_string() + if ratelimit: + self.ratelimit_invite(room_id, target_id) + # block any attempts to invite the server notices mxid - if target.to_string() == self._server_notices_mxid: + if target_id == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -411,8 +478,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) block_invite = True + is_published = await self.store.is_room_published(room_id) + if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), + target_id, + third_party_invite=None, + room_id=room_id, + new_room=new_room, + published_room=is_published, ): logger.info("Blocking invite due to spam checker") block_invite = True @@ -490,6 +564,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to join rooms + is_requester_admin = True + + else: + is_requester_admin = await self.auth.is_server_admin(requester.user) + + inviter = await self._get_inviter(target.to_string(), room_id) + if not is_requester_admin: + # We assume that if the spam checker allowed the user to create + # a room then they're allowed to join it. + if not new_room and not self.spam_checker.user_may_join_room( + target.to_string(), room_id, is_invited=inviter is not None + ): + raise SynapseError(403, "Not allowed to join this room") + if not is_host_in_room: if ratelimit: time_now_s = self.clock.time() @@ -527,50 +620,76 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: - # perhaps we've been invited + # Figure out the user's current membership state for the room ( current_membership_type, current_membership_event_id, ) = await self.store.get_local_current_membership_for_user_in_room( target.to_string(), room_id ) - if ( - current_membership_type != Membership.INVITE - or not current_membership_event_id - ): + if not current_membership_type or not current_membership_event_id: logger.info( "%s sent a leave request to %s, but that is not an active room " - "on this server, and there is no pending invite", + "on this server, or there is no pending invite or knock", target, room_id, ) raise SynapseError(404, "Not a known room") - invite = await self.store.get_event(current_membership_event_id) - logger.info( - "%s rejects invite to %s from %s", target, room_id, invite.sender - ) + # perhaps we've been invited + if current_membership_type == Membership.INVITE: + invite = await self.store.get_event(current_membership_event_id) + logger.info( + "%s rejects invite to %s from %s", + target, + room_id, + invite.sender, + ) + + if not self.hs.is_mine_id(invite.sender): + # send the rejection to the inviter's HS (with fallback to + # local event) + return await self.remote_reject_invite( + invite.event_id, txn_id, requester, content, + ) - if not self.hs.is_mine_id(invite.sender): - # send the rejection to the inviter's HS (with fallback to - # local event) - return await self.remote_reject_invite( - invite.event_id, txn_id, requester, content, + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath, which will also send the + # rejection out to any other servers we believe are still in the room. + + # thanks to overzealous cleaning up of event_forward_extremities in + # `delete_old_current_state_events`, it's possible to end up with no + # forward extremities here. If that happens, let's just hang the + # rejection off the invite event. + # + # see: https://github.com/matrix-org/synapse/issues/7139 + if len(latest_event_ids) == 0: + latest_event_ids = [invite.event_id] + + # or perhaps this is a remote room that a local user has knocked on + elif current_membership_type == Membership.KNOCK: + knock = await self.store.get_event(current_membership_event_id) + return await self.remote_rescind_knock( + knock.event_id, txn_id, requester, content ) - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath, which will also send the - # rejection out to any other servers we believe are still in the room. + elif effective_membership_state == Membership.KNOCK: + if not is_host_in_room: + # The knock needs to be sent over federation instead + remote_room_hosts.append(get_domain_from_id(room_id)) - # thanks to overzealous cleaning up of event_forward_extremities in - # `delete_old_current_state_events`, it's possible to end up with no - # forward extremities here. If that happens, let's just hang the - # rejection off the invite event. - # - # see: https://github.com/matrix-org/synapse/issues/7139 - if len(latest_event_ids) == 0: - latest_event_ids = [invite.event_id] + content["membership"] = Membership.KNOCK + + profile = self.profile_handler + if "displayname" not in content: + content["displayname"] = await profile.get_displayname(target) + if "avatar_url" not in content: + content["avatar_url"] = await profile.get_avatar_url(target) + + return await self.remote_knock( + remote_room_hosts, room_id, target, content + ) return await self._local_membership_update( requester=requester, @@ -786,6 +905,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): id_server: str, requester: Requester, txn_id: Optional[str], + new_room: bool = False, id_access_token: Optional[str] = None, ) -> int: """Invite a 3PID to a room. @@ -833,6 +953,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): Codes.FORBIDDEN, ) + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( + medium, address, room_id + ) + if not can_invite: + raise SynapseError( + 403, + "This third-party identifier can not be invited in this room", + Codes.FORBIDDEN, + ) + if not self._enable_lookup: raise SynapseError( 403, "Looking up third-party identifiers is denied from this server" @@ -842,6 +972,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): id_server, medium, address, id_access_token ) + is_published = await self.store.is_room_published(room_id) + + if not await self.spam_checker.user_may_invite( + requester.user.to_string(), + invitee, + third_party_invite={"medium": medium, "address": address}, + room_id=room_id, + new_room=new_room, + published_room=is_published, + ): + logger.info("Blocking invite due to spam checker") + raise SynapseError(403, "Invites have been disabled on this server") + if invitee: # Note that update_membership with an action of "invite" can raise # a ShadowBanError, but this was done above already. @@ -1131,6 +1274,35 @@ class RoomMemberMasterHandler(RoomMemberHandler): invite_event, txn_id, requester, content ) + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """ + Rescinds a local knock made on a remote room + + Args: + knock_event_id: The ID of the knock event to rescind. + txn_id: The transaction ID to use. + requester: The originator of the request. + content: The content of the leave event. + + Implements RoomMemberHandler.remote_rescind_knock + """ + # TODO: We don't yet support rescinding knocks over federation + # as we don't know which homeserver to send it to. An obvious + # candidate is the remote homeserver we originally knocked through, + # however we don't currently store that information. + + # Just rescind the knock locally + knock_event = await self.store.get_event(knock_event_id) + return await self._generate_local_out_of_band_leave( + knock_event, txn_id, requester, content + ) + async def _generate_local_out_of_band_leave( self, previous_membership_event: EventBase, @@ -1191,6 +1363,32 @@ class RoomMemberMasterHandler(RoomMemberHandler): return result_event.event_id, result_event.internal_metadata.stream_ordering + async def remote_knock( + self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict, + ) -> Tuple[str, int]: + """Sends a knock to a room. Attempts to do so via one remote out of a given list. + + Args: + remote_room_hosts: A list of homeservers to try knocking through. + room_id: The ID of the room to knock on. + user: The user to knock on behalf of. + content: The content of the knock event. + + Returns: + A tuple of (event ID, stream ID). + """ + # filter ourselves out of remote_room_hosts + remote_room_hosts = [ + host for host in remote_room_hosts if host != self.hs.hostname + ] + + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + return await self.federation_handler.do_knock( + remote_room_hosts, room_id, user.to_string(), content=content + ) + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index f2e88f6a5b..3de63e885e 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,10 +21,12 @@ from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler from synapse.replication.http.membership import ( ReplicationRemoteJoinRestServlet as ReplRemoteJoin, + ReplicationRemoteKnockRestServlet as ReplRemoteKnock, ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, + ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID logger = logging.getLogger(__name__) @@ -33,7 +36,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler): super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) + self._remote_knock_client = ReplRemoteKnock.make_client(hs) self._remote_reject_client = ReplRejectInvite.make_client(hs) + self._remote_rescind_client = ReplRescindKnock.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) async def _remote_join( @@ -79,6 +84,49 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) return ret["event_id"], ret["stream_id"] + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """ + Rescinds a local knock made on a remote room + + Args: + knock_event_id: the knock event + txn_id: optional transaction ID supplied by the client + requester: user making the request, according to the access token + content: additional content to include in the leave event. + Normally an empty dict. + + Returns: + A tuple containing (event_id, stream_id of the leave event) + """ + ret = await self._remote_rescind_client( + knock_event_id=knock_event_id, + txn_id=txn_id, + requester=requester, + content=content, + ) + return ret["event_id"], ret["stream_id"] + + async def remote_knock( + self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict, + ) -> Tuple[str, int]: + """Sends a knock to a room. + + Implements RoomMemberHandler.remote_knock + """ + ret = await self._remote_knock_client( + remote_room_hosts=remote_room_hosts, + room_id=room_id, + user=user, + content=content, + ) + return ret["event_id"], ret["stream_id"] + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 66f1bbcfc4..94062e79cb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py
@@ -15,23 +15,28 @@ import itertools import logging -from typing import Iterable +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.storage.state import StateFilter +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() @@ -87,13 +92,15 @@ class SearchHandler(BaseHandler): return historical_room_ids - async def search(self, user, content, batch=None): + async def search( + self, user: UserID, content: JsonDict, batch: Optional[str] = None + ) -> JsonDict: """Performs a full text search for a user. Args: - user (UserID) - content (dict): Search parameters - batch (str): The next_batch parameter. Used for pagination. + user + content: Search parameters + batch: The next_batch parameter. Used for pagination. Returns: dict to be returned to the client with results of search @@ -186,7 +193,7 @@ class SearchHandler(BaseHandler): # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] + historical_room_ids = [] # type: List[str] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -209,8 +216,10 @@ class SearchHandler(BaseHandler): rank_map = {} # event_id -> rank of event allowed_events = [] - room_groups = {} # Holds result of grouping by room, if applicable - sender_group = {} # Holds result of grouping by sender, if applicable + # Holds result of grouping by room, if applicable + room_groups = {} # type: Dict[str, JsonDict] + # Holds result of grouping by sender, if applicable + sender_group = {} # type: Dict[str, JsonDict] # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -254,7 +263,7 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] + room_events = [] # type: List[EventBase] i = 0 pagination_token = batch_token @@ -418,13 +427,10 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = {e.room_id for e in allowed_events} - for room_id in rooms: + for room_id in {e.room_id for e in allowed_events}: state = await self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) - state_results.values() - # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise the 'age' will be wrong @@ -448,9 +454,9 @@ class SearchHandler(BaseHandler): if state_results: s = {} - for room_id, state in state_results.items(): + for room_id, state_events in state_results.items(): s[room_id] = await self._event_serializer.serialize_events( - state, time_now + state_events, time_now ) rooms_cat_res["state"] = s diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index a5d67f828f..cef6b3ae48 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,24 +14,26 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - self._password_policy_handler = hs.get_password_policy_handler() async def set_password( self, @@ -38,7 +41,7 @@ class SetPasswordHandler(BaseHandler): password_hash: str, logout_devices: bool, requester: Optional[Requester] = None, - ): + ) -> None: if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index fb4f70e8e2..b3f9875358 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py
@@ -14,15 +14,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class StateDeltasHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change( + self, + prev_event_id: Optional[str], + event_id: Optional[str], + key_name: str, + public_value: str, + ) -> Optional[bool]: """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index dc62b21c06..0b5e62da1b 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py
@@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2020 Sorunome +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +14,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from collections import Counter +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple + +from typing_extensions import Counter as CounterType from synapse.api.constants import EventTypes, Membership from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -31,7 +39,7 @@ class StatsHandler: Heavily derived from UserDirectoryHandler """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -44,7 +52,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None + self.pos = None # type: Optional[int] # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -56,7 +64,7 @@ class StatsHandler: # we start populating stats self.clock.call_later(0, self.notify_new_event) - def notify_new_event(self): + def notify_new_event(self) -> None: """Called when there may be more deltas to process """ if not self.stats_enabled or self._is_processing: @@ -72,7 +80,7 @@ class StatsHandler: run_as_background_process("stats.notify_new_event", process) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB if self.pos is None: self.pos = await self.store.get_stats_positions() @@ -110,10 +118,10 @@ class StatsHandler: ) for room_id, fields in room_count.items(): - room_deltas.setdefault(room_id, {}).update(fields) + room_deltas.setdefault(room_id, Counter()).update(fields) for user_id, fields in user_count.items(): - user_deltas.setdefault(user_id, {}).update(fields) + user_deltas.setdefault(user_id, Counter()).update(fields) logger.debug("room_deltas: %s", room_deltas) logger.debug("user_deltas: %s", user_deltas) @@ -131,19 +139,20 @@ class StatsHandler: self.pos = max_pos - async def _handle_deltas(self, deltas): + async def _handle_deltas( + self, deltas: Iterable[JsonDict] + ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]: """Called with the state deltas to process Returns: - tuple[dict[str, Counter], dict[str, counter]] Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. """ - room_to_stats_deltas = {} - user_to_stats_deltas = {} + room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - room_to_state_updates = {} + room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] for delta in deltas: typ = delta["type"] @@ -173,7 +182,7 @@ class StatsHandler: ) continue - event_content = {} + event_content = {} # type: JsonDict sender = None if event_id is not None: @@ -225,6 +234,8 @@ class StatsHandler: room_stats_delta["left_members"] -= 1 elif prev_membership == Membership.BAN: room_stats_delta["banned_members"] -= 1 + elif prev_membership == Membership.KNOCK: + room_stats_delta["knocked_members"] -= 1 else: raise ValueError( "%r is not a valid prev_membership" % (prev_membership,) @@ -246,6 +257,8 @@ class StatsHandler: room_stats_delta["left_members"] += 1 elif membership == Membership.BAN: room_stats_delta["banned_members"] += 1 + elif membership == Membership.KNOCK: + room_stats_delta["knocked_members"] += 1 else: raise ValueError("%r is not a valid membership" % (membership,)) @@ -257,13 +270,13 @@ class StatsHandler: ) if has_changed_joinedness: - delta = +1 if membership == Membership.JOIN else -1 + membership_delta = +1 if membership == Membership.JOIN else -1 user_to_stats_deltas.setdefault(user_id, Counter())[ "joined_rooms" - ] += delta + ] += membership_delta - room_stats_delta["local_users_in_room"] += delta + room_stats_delta["local_users_in_room"] += membership_delta elif typ == EventTypes.Create: room_state["is_federatable"] = ( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5c7590f38e..e8947e0f9b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -151,6 +151,16 @@ class InvitedSyncResult: @attr.s(slots=True, frozen=True) +class KnockedSyncResult: + room_id = attr.ib(type=str) + knock = attr.ib(type=EventBase) + + def __bool__(self) -> bool: + """Knocked rooms should always be reported to the client""" + return True + + +@attr.s(slots=True, frozen=True) class GroupsSyncResult: join = attr.ib(type=JsonDict) invite = attr.ib(type=JsonDict) @@ -183,6 +193,7 @@ class _RoomChanges: room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) invited = attr.ib(type=List[InvitedSyncResult]) + knocked = attr.ib(type=List[KnockedSyncResult]) newly_joined_rooms = attr.ib(type=List[str]) newly_left_rooms = attr.ib(type=List[str]) @@ -196,6 +207,7 @@ class SyncResult: account_data: List of account_data events for the user. joined: JoinedSyncResult for each joined room. invited: InvitedSyncResult for each invited room. + knocked: KnockedSyncResult for each knocked on room. archived: ArchivedSyncResult for each archived room. to_device: List of direct messages for the device. device_lists: List of user_ids whose devices have changed @@ -211,6 +223,7 @@ class SyncResult: account_data = attr.ib(type=List[JsonDict]) joined = attr.ib(type=List[JoinedSyncResult]) invited = attr.ib(type=List[InvitedSyncResult]) + knocked = attr.ib(type=List[KnockedSyncResult]) archived = attr.ib(type=List[ArchivedSyncResult]) to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) @@ -227,6 +240,7 @@ class SyncResult: self.presence or self.joined or self.invited + or self.knocked or self.archived or self.account_data or self.to_device @@ -999,7 +1013,7 @@ class SyncHandler: res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_or_invited_users, _, _ = res + newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( @@ -1008,7 +1022,9 @@ class SyncHandler: if self.hs_config.use_presence and not block_all_presence_data: logger.debug("Fetching presence data") await self._generate_sync_entry_for_presence( - sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, ) logger.debug("Fetching to-device data") @@ -1017,7 +1033,7 @@ class SyncHandler: device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, - newly_joined_or_invited_users=newly_joined_or_invited_users, + newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, newly_left_rooms=newly_left_rooms, newly_left_users=newly_left_users, ) @@ -1051,6 +1067,7 @@ class SyncHandler: account_data=sync_result_builder.account_data, joined=sync_result_builder.joined, invited=sync_result_builder.invited, + knocked=sync_result_builder.knocked, archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, @@ -1110,7 +1127,7 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", newly_joined_rooms: Set[str], - newly_joined_or_invited_users: Set[str], + newly_joined_or_invited_or_knocked_users: Set[str], newly_left_rooms: Set[str], newly_left_users: Set[str], ) -> DeviceLists: @@ -1119,8 +1136,9 @@ class SyncHandler: Args: sync_result_builder newly_joined_rooms: Set of rooms user has joined since previous sync - newly_joined_or_invited_users: Set of users that have joined or - been invited to a room since previous sync. + newly_joined_or_invited_or_knocked_users: Set of users that have joined, + been invited to a room or are knocking on a room since + previous sync. newly_left_rooms: Set of rooms user has left since previous sync newly_left_users: Set of users that have left a room we're in since previous sync @@ -1131,7 +1149,9 @@ class SyncHandler: # We're going to mutate these fields, so lets copy them rather than # assume they won't get used later. - newly_joined_or_invited_users = set(newly_joined_or_invited_users) + newly_joined_or_invited_or_knocked_users = set( + newly_joined_or_invited_or_knocked_users + ) newly_left_users = set(newly_left_users) if since_token and since_token.device_list_key: @@ -1170,11 +1190,11 @@ class SyncHandler: # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: joined_users = await self.state.get_current_users_in_room(room_id) - newly_joined_or_invited_users.update(joined_users) + newly_joined_or_invited_or_knocked_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. - users_that_have_changed.update(newly_joined_or_invited_users) + users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) user_signatures_changed = await self.store.get_users_whose_signatures_changed( user_id, since_token.device_list_key @@ -1419,6 +1439,7 @@ class SyncHandler: room_entries = room_changes.room_entries invited = room_changes.invited + knocked = room_changes.knocked newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms @@ -1439,9 +1460,10 @@ class SyncHandler: await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) + sync_result_builder.knocked.extend(knocked) - # Now we want to get any newly joined or invited users - newly_joined_or_invited_users = set() + # Now we want to get any newly joined, invited or knocking users + newly_joined_or_invited_or_knocked_users = set() newly_left_users = set() if since_token: for joined_sync in sync_result_builder.joined: @@ -1453,19 +1475,22 @@ class SyncHandler: if ( event.membership == Membership.JOIN or event.membership == Membership.INVITE + or event.membership == Membership.KNOCK ): - newly_joined_or_invited_users.add(event.state_key) + newly_joined_or_invited_or_knocked_users.add( + event.state_key + ) else: prev_content = event.unsigned.get("prev_content", {}) prev_membership = prev_content.get("membership", None) if prev_membership == Membership.JOIN: newly_left_users.add(event.state_key) - newly_left_users -= newly_joined_or_invited_users + newly_left_users -= newly_joined_or_invited_or_knocked_users return ( set(newly_joined_rooms), - newly_joined_or_invited_users, + newly_joined_or_invited_or_knocked_users, set(newly_left_rooms), newly_left_users, ) @@ -1521,6 +1546,7 @@ class SyncHandler: newly_left_rooms = [] room_entries = [] invited = [] + knocked = [] for room_id, events in mem_change_events_by_room_id.items(): logger.debug( "Membership changes in %s: [%s]", @@ -1600,9 +1626,17 @@ class SyncHandler: should_invite = non_joins[-1].membership == Membership.INVITE if should_invite: if event.sender not in ignored_users: - room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) - if room_sync: - invited.append(room_sync) + invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if invite_room_sync: + invited.append(invite_room_sync) + + # Only bother if our latest membership in the room is knock (and we haven't + # been accepted/rejected in the meantime). + should_knock = non_joins[-1].membership == Membership.KNOCK + if should_knock: + knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1]) + if knock_room_sync: + knocked.append(knock_room_sync) # Always include leave/ban events. Just take the last one. # TODO: How do we handle ban -> leave in same batch? @@ -1706,7 +1740,9 @@ class SyncHandler: ) room_entries.append(entry) - return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) + return _RoomChanges( + room_entries, invited, knocked, newly_joined_rooms, newly_left_rooms, + ) async def _get_all_rooms( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] @@ -1726,6 +1762,7 @@ class SyncHandler: membership_list = ( Membership.INVITE, + Membership.KNOCK, Membership.JOIN, Membership.LEAVE, Membership.BAN, @@ -1737,6 +1774,7 @@ class SyncHandler: room_entries = [] invited = [] + knocked = [] for event in room_list: if event.membership == Membership.JOIN: @@ -1756,8 +1794,11 @@ class SyncHandler: continue invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) + elif event.membership == Membership.KNOCK: + knock = await self.store.get_event(event.event_id) + knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock)) elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. + # Always send down rooms we were banned from or kicked from. if not sync_config.filter_collection.include_leave: if event.membership == Membership.LEAVE: if user_id == event.sender: @@ -1778,7 +1819,7 @@ class SyncHandler: ) ) - return _RoomChanges(room_entries, invited, [], []) + return _RoomChanges(room_entries, invited, knocked, [], []) async def _generate_room_entry( self, @@ -2067,6 +2108,7 @@ class SyncResultBuilder: account_data (list) joined (list[JoinedSyncResult]) invited (list[InvitedSyncResult]) + knocked (list[KnockedSyncResult]) archived (list[ArchivedSyncResult]) groups (GroupsSyncResult|None) to_device (list) @@ -2082,6 +2124,7 @@ class SyncResultBuilder: account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) + knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list)) archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) groups = attr.ib(type=Optional[GroupsSyncResult], default=None) to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index e919a8f9ed..3f0dfc7a74 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -15,13 +15,13 @@ import logging import random from collections import namedtuple -from typing import TYPE_CHECKING, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -65,17 +65,17 @@ class FollowerTypingHandler: ) # map room IDs to serial numbers - self._room_serials = {} + self._room_serials = {} # type: Dict[str, int] # map room IDs to sets of users currently typing - self._room_typing = {} + self._room_typing = {} # type: Dict[str, Set[str]] - self._member_last_federation_poke = {} + self._member_last_federation_poke = {} # type: Dict[RoomMember, int] self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 self.clock.looping_call(self._handle_timeouts, 5000) - def _reset(self): + def _reset(self) -> None: """Reset the typing handler's data caches. """ # map room IDs to serial numbers @@ -86,7 +86,7 @@ class FollowerTypingHandler: self._member_last_federation_poke = {} self.wheel_timer = WheelTimer(bucket_size=5000) - def _handle_timeouts(self): + def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() @@ -96,7 +96,7 @@ class FollowerTypingHandler: for member in members: self._handle_timeout_for_member(now, member) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: if not self.is_typing(member): # Nothing to do if they're no longer typing return @@ -114,10 +114,10 @@ class FollowerTypingHandler: # each person typing. self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) - def is_typing(self, member): + def is_typing(self, member: RoomMember) -> bool: return member.user_id in self._room_typing.get(member.room_id, []) - async def _push_remote(self, member, typing): + async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: return @@ -148,7 +148,7 @@ class FollowerTypingHandler: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: """Should be called whenever we receive updates for typing stream. """ @@ -178,7 +178,7 @@ class FollowerTypingHandler: async def _send_changes_in_typing_to_remotes( self, room_id: str, prev_typing: Set[str], now_typing: Set[str] - ): + ) -> None: """Process a change in typing of a room from replication, sending EDUs for any local users. """ @@ -194,12 +194,12 @@ class FollowerTypingHandler: if self.is_mine_id(user_id): await self._push_remote(RoomMember(room_id, user_id), False) - def get_current_token(self): + def get_current_token(self) -> int: return self._latest_room_serial class TypingWriterHandler(FollowerTypingHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) assert hs.config.worker.writers.typing == hs.get_instance_name() @@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - self._member_typing_until = {} # clock time we expect to stop + # clock time we expect to stop + self._member_typing_until = {} # type: Dict[RoomMember, int] # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", self._latest_room_serial ) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: super()._handle_timeout_for_member(now, member) if not self.is_typing(member): @@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) return - async def started_typing(self, target_user, requester, room_id, timeout): + async def started_typing( + self, target_user: UserID, requester: Requester, room_id: str, timeout: int + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler): if was_present: # No point sending another notification - return None + return self._push_update(member=member, typing=True) - async def stopped_typing(self, target_user, requester, room_id): + async def stopped_typing( + self, target_user: UserID, requester: Requester, room_id: str + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) - def user_left_room(self, user, room_id): + def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) self._stopped_typing(member) - def _stopped_typing(self, member): + def _stopped_typing(self, member: RoomMember) -> None: if member.user_id not in self._room_typing.get(member.room_id, set()): # No point - return None + return self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) self._push_update(member=member, typing=False) - def _push_update(self, member, typing): + def _push_update(self, member: RoomMember, typing: bool) -> None: if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_as_background_process( @@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler): self._push_update_local(member=member, typing=typing) - async def _recv_edu(self, origin, content): + async def _recv_edu(self, origin: str, content: JsonDict) -> None: room_id = content["room_id"] user_id = content["user_id"] @@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler): self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) self._push_update_local(member=member, typing=content["typing"]) - def _push_update_local(self, member, typing): + def _push_update_local(self, member: RoomMember, typing: bool) -> None: room_set = self._room_typing.setdefault(member.room_id, set()) if typing: room_set.add(member.user_id) @@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler): changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id - ) + ) # type: Optional[Iterable[str]] if changed_rooms is None: changed_rooms = self._room_serials @@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler): def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: # The writing process should never get updates from replication. raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: @@ -427,7 +432,7 @@ class TypingNotificationEventSource: # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id): + def _make_event_for(self, room_id: str) -> JsonDict: typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", @@ -462,7 +467,9 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - async def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events( + self, from_key: int, room_ids: Iterable[str], **kwargs + ) -> Tuple[List[JsonDict], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -478,5 +485,5 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - def get_current_key(self): + def get_current_key(self) -> int: return self.get_typing_handler()._latest_room_serial diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d4651c8348..8aedf5072e 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler): if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() - # If still None then the initial background update hasn't happened yet - if self.pos is None: - return None - # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "user_dir_delta"): @@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler): if change: # The user joined event = await self.store.get_event(event_id, allow_none=True) + # It isn't expected for this event to not exist, but we + # don't want the entire background process to break. + if event is None: + continue + profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37ccf5ab98..8eb93ba73e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -289,8 +289,7 @@ class SimpleHttpClient: treq_args: Dict[str, Any] = {}, ip_whitelist: Optional[IPSet] = None, ip_blacklist: Optional[IPSet] = None, - http_proxy: Optional[bytes] = None, - https_proxy: Optional[bytes] = None, + use_proxy: bool = False, ): """ Args: @@ -300,8 +299,8 @@ class SimpleHttpClient: we may not request. ip_whitelist: The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. - http_proxy: proxy server to use for http connections. host[:port] - https_proxy: proxy server to use for https connections. host[:port] + use_proxy: Whether proxy settings should be discovered and used + from conventional environment variables. """ self.hs = hs @@ -345,8 +344,7 @@ class SimpleHttpClient: connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, - http_proxy=http_proxy, - https_proxy=https_proxy, + use_proxy=use_proxy, ) if self._ip_blacklist: diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py
@@ -19,9 +19,10 @@ from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import ConnectError -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.internet.protocol import connectionDone +from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.protocol import ClientFactory, Protocol, connectionDone from twisted.web import http +from twisted.web.http_headers import Headers logger = logging.getLogger(__name__) @@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint: Args: reactor: the Twisted reactor to use for the connection - proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the - proxy - host (bytes): hostname that we want to CONNECT to - port (int): port that we want to connect to + proxy_endpoint: the endpoint to use to connect to the proxy + host: hostname that we want to CONNECT to + port: port that we want to connect to + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, reactor, proxy_endpoint, host, port): + def __init__( + self, + reactor: IReactorCore, + proxy_endpoint: IStreamClientEndpoint, + host: bytes, + port: int, + headers: Headers, + ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint self._host = host self._port = port + self._headers = headers def __repr__(self): return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) - def connect(self, protocolFactory): - f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + def connect(self, protocolFactory: ClientFactory): + f = HTTPProxiedClientFactory( + self._host, self._port, protocolFactory, self._headers + ) d = self._proxy_endpoint.connect(f) # once the tcp socket connects successfully, we need to wait for the # CONNECT to complete. @@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): HTTP Protocol object and run the rest of the connection. Args: - dst_host (bytes): hostname that we want to CONNECT to - dst_port (int): port that we want to connect to - wrapped_factory (protocol.ClientFactory): The original Factory + dst_host: hostname that we want to CONNECT to + dst_port: port that we want to connect to + wrapped_factory: The original Factory + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, dst_host, dst_port, wrapped_factory): + def __init__( + self, + dst_host: bytes, + dst_port: int, + wrapped_factory: ClientFactory, + headers: Headers, + ): self.dst_host = dst_host self.dst_port = dst_port self.wrapped_factory = wrapped_factory + self.headers = headers self.on_connection = defer.Deferred() def startedConnecting(self, connector): @@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): wrapped_protocol = self.wrapped_factory.buildProtocol(addr) return HTTPConnectProtocol( - self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + self.dst_host, + self.dst_port, + wrapped_protocol, + self.on_connection, + self.headers, ) def clientConnectionFailed(self, connector, reason): @@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol): """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect Args: - host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + host: The original HTTP(s) hostname or IPv4 or IPv6 address literal to put in the CONNECT request - port (int): The original HTTP(s) port to put in the CONNECT request + port: The original HTTP(s) port to put in the CONNECT request - wrapped_protocol (interfaces.IProtocol): the original protocol (probably - HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + wrapped_protocol: the original protocol (probably HTTPChannel or + TLSMemoryBIOProtocol, but could be anything really) - connected_deferred (Deferred): a Deferred which will be callbacked with + connected_deferred: a Deferred which will be callbacked with wrapped_protocol when the CONNECT completes + + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, host, port, wrapped_protocol, connected_deferred): + def __init__( + self, + host: bytes, + port: int, + wrapped_protocol: Protocol, + connected_deferred: defer.Deferred, + headers: Headers, + ): self.host = host self.port = port self.wrapped_protocol = wrapped_protocol self.connected_deferred = connected_deferred - self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.headers = headers + + self.http_setup_client = HTTPConnectSetupClient( + self.host, self.port, self.headers + ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) def connectionMade(self): @@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol): if buf: self.wrapped_protocol.dataReceived(buf) - def dataReceived(self, data): + def dataReceived(self, data: bytes): # if we've set up the HTTP protocol, we can send the data there if self.wrapped_protocol.connected: return self.wrapped_protocol.dataReceived(data) @@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient): """HTTPClient protocol to send a CONNECT message for proxies and read the response. Args: - host (bytes): The hostname to send in the CONNECT message - port (int): The port to send in the CONNECT message + host: The hostname to send in the CONNECT message + port: The port to send in the CONNECT message + headers: Extra headers to send with the CONNECT message """ - def __init__(self, host, port): + def __init__(self, host: bytes, port: int, headers: Headers): self.host = host self.port = port + self.headers = headers self.on_connected = defer.Deferred() def connectionMade(self): logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + + # Send any additional specified headers + for name, values in self.headers.getAllRawHeaders(): + for value in values: + self.sendHeader(name, value) + self.endHeaders() - def handleStatus(self, version, status, message): + def handleStatus(self, version: bytes, status: bytes, message: bytes): logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b730d2c634..ee65a6668b 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py
@@ -12,9 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging import re +from typing import Optional, Tuple +from urllib.request import getproxies_environment, proxy_bypass_environment +import attr from zope.interface import implementer from twisted.internet import defer @@ -22,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.python.failure import Failure from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase from twisted.web.error import SchemeNotSupported +from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint @@ -31,6 +36,22 @@ logger = logging.getLogger(__name__) _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z") +@attr.s +class ProxyCredentials: + username_password = attr.ib(type=bytes) + + def as_proxy_authorization_value(self) -> bytes: + """ + Return the value for a Proxy-Authorization header (i.e. 'Basic abdef=='). + + Returns: + A transformation of the authentication string the encoded value for + a Proxy-Authorization header. + """ + # Encode as base64 and prepend the authorization type + return b"Basic " + base64.encodebytes(self.username_password) + + @implementer(IAgent) class ProxyAgent(_AgentBase): """An Agent implementation which will use an HTTP proxy if one was requested @@ -58,6 +79,9 @@ class ProxyAgent(_AgentBase): pool (HTTPConnectionPool|None): connection pool to be used. If None, a non-persistent pool instance will be created. + + use_proxy (bool): Whether proxy settings should be discovered and used + from conventional environment variables. """ def __init__( @@ -68,8 +92,7 @@ class ProxyAgent(_AgentBase): connectTimeout=None, bindAddress=None, pool=None, - http_proxy=None, - https_proxy=None, + use_proxy=False, ): _AgentBase.__init__(self, reactor, pool) @@ -84,6 +107,18 @@ class ProxyAgent(_AgentBase): if bindAddress is not None: self._endpoint_kwargs["bindAddress"] = bindAddress + http_proxy = None + https_proxy = None + no_proxy = None + if use_proxy: + proxies = getproxies_environment() + http_proxy = proxies["http"].encode() if "http" in proxies else None + https_proxy = proxies["https"].encode() if "https" in proxies else None + no_proxy = proxies["no"] if "no" in proxies else None + + # Parse credentials from https proxy connection string if present + self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) + self.http_proxy_endpoint = _http_proxy_endpoint( http_proxy, self.proxy_reactor, **self._endpoint_kwargs ) @@ -92,6 +127,8 @@ class ProxyAgent(_AgentBase): https_proxy, self.proxy_reactor, **self._endpoint_kwargs ) + self.no_proxy = no_proxy + self._policy_for_https = contextFactory self._reactor = reactor @@ -139,18 +176,43 @@ class ProxyAgent(_AgentBase): pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) request_path = parsed_uri.originForm - if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + should_skip_proxy = False + if self.no_proxy is not None: + should_skip_proxy = proxy_bypass_environment( + parsed_uri.host.decode(), proxies={"no": self.no_proxy}, + ) + + if ( + parsed_uri.scheme == b"http" + and self.http_proxy_endpoint + and not should_skip_proxy + ): # Cache *all* connections under the same key, since we are only # connecting to a single destination, the proxy: pool_key = ("http-proxy", self.http_proxy_endpoint) endpoint = self.http_proxy_endpoint request_path = uri - elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + elif ( + parsed_uri.scheme == b"https" + and self.https_proxy_endpoint + and not should_skip_proxy + ): + connect_headers = Headers() + + # Determine whether we need to set Proxy-Authorization headers + if self.https_proxy_creds: + # Set a Proxy-Authorization header + connect_headers.addRawHeader( + b"Proxy-Authorization", + self.https_proxy_creds.as_proxy_authorization_value(), + ) + endpoint = HTTPConnectProxyEndpoint( self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, + headers=connect_headers, ) else: # not using a proxy @@ -179,12 +241,16 @@ class ProxyAgent(_AgentBase): ) -def _http_proxy_endpoint(proxy, reactor, **kwargs): +def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs): """Parses an http proxy setting and returns an endpoint for the proxy Args: - proxy (bytes|None): the proxy setting + proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>] + Note that compared to other apps, this function currently lacks support + for specifying a protocol schema (i.e. protocol://...). + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint Returns: @@ -194,16 +260,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs): if proxy is None: return None - # currently we only support hostname:port. Some apps also support - # protocol://<host>[:port], which allows a way of requiring a TLS connection to the - # proxy. - + # Parse the connection string host, port = parse_host_port(proxy, default_port=1080) return HostnameEndpoint(reactor, host, port, **kwargs) -def parse_host_port(hostport, default_port=None): - # could have sworn we had one of these somewhere else... +def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]: + """ + Parses the username and password from a proxy declaration e.g + username:password@hostname:port. + + Args: + proxy: The proxy connection string. + + Returns + An instance of ProxyCredentials and the proxy connection string with any credentials + stripped, i.e u:p@host:port -> host:port. If no credentials were found, the + ProxyCredentials instance is replaced with None. + """ + if proxy and b"@" in proxy: + # We use rsplit here as the password could contain an @ character + credentials, proxy_without_credentials = proxy.rsplit(b"@", 1) + return ProxyCredentials(credentials), proxy_without_credentials + + return None, proxy + + +def parse_host_port(hostport: bytes, default_port: int = None) -> Tuple[bytes, int]: + """ + Parse the hostname and port from a proxy connection byte string. + + Args: + hostport: The proxy connection string. Must be in the form 'host[:port]'. + default_port: The default port to return if one is not found in `hostport`. + + Returns: + A tuple containing the hostname and port. Uses `default_port` if one was not found. + """ if b":" in hostport: host, port = hostport.rsplit(b":", 1) try: diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index b361b7cbaf..9bfe151b5f 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -14,8 +14,8 @@ # limitations under the License. """ This module contains base REST classes for constructing REST servlets. """ - import logging +from typing import Dict, List, Optional, Union from synapse.api.errors import Codes, SynapseError from synapse.util import json_decoder @@ -147,16 +147,67 @@ def parse_string( ) +def parse_list_from_args( + args: Dict[bytes, List[bytes]], + name: Union[bytes, str], + encoding: Optional[str] = "ascii", +): + """Parse and optionally decode a list of values from request query parameters. + + Args: + args: A dictionary of query parameters from a request. + name: The name of the query parameter to extract values from. If given as bytes, + will be decoded as "ascii". + encoding: An optional encoding that is used to decode each parameter value with. + + Raises: + KeyError: If the given `name` does not exist in `args`. + SynapseError: If an argument was not encoded with the specified `encoding`. + """ + if not isinstance(name, bytes): + name = name.encode("ascii") + args_list = args[name] + + if encoding: + # Decode each argument value + try: + args_list = [value.decode(encoding) for value in args_list] + except ValueError: + raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + + return args_list + + def parse_string_from_args( - args, - name, - default=None, - required=False, - allowed_values=None, - param_type="string", - encoding="ascii", + args: Dict[bytes, List[bytes]], + name: Union[bytes, str], + default: Optional[str] = None, + required: Optional[bool] = False, + allowed_values: Optional[List[bytes]] = None, + param_type: Optional[str] = "string", + encoding: Optional[str] = "ascii", ): + """Parse and optionally decode a single value from request query parameters. + Args: + args: A dictionary of query parameters from a request. + name: The name of the query parameter to extract values from. If given as bytes, + will be decoded as "ascii". + default: A default value to return if the given argument `name` was not found. + required: If this is True, no `default` is provided and the given argument `name` + was not found then a SynapseError is raised. + allowed_values: A list of allowed values. If specified and the found str is + not in this list, a SynapseError is raised. + param_type: The expected type of the query parameter's value. + encoding: An optional encoding that is used to decode each parameter value with. + + Returns: + The found argument value. + + Raises: + SynapseError: If the given name was not found in the request arguments, + the argument's values were encoded incorrectly or a required value was missing. + """ if not isinstance(name, bytes): name = name.encode("ascii") diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index ab586c318c..0538350f38 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -791,7 +791,7 @@ def tag_args(func): @wraps(func) def _tag_args_inner(*args, **kwargs): - argspec = inspect.getargspec(func) + argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): set_tag("ARG_" + arg, args[i]) set_tag("args", args[len(argspec.args) :]) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6211506990..4d284de133 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py
@@ -495,7 +495,11 @@ BASE_APPEND_UNDERRIDE_RULES = [ "_id": "_message", } ], - "actions": ["notify", {"set_tweak": "highlight", "value": False}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... @@ -509,7 +513,11 @@ BASE_APPEND_UNDERRIDE_RULES = [ "_id": "_encrypted", } ], - "actions": ["notify", {"set_tweak": "highlight", "value": False}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, { "rule_id": "global/underride/.im.vector.jitsi", diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 4d875dcb91..745b1dde94 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -668,6 +668,15 @@ class Mailer: def safe_markup(raw_html: str) -> jinja2.Markup: + """ + Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs. + + Args + raw_html: Unsafe HTML. + + Returns: + A Markup object ready to safely use in a Jinja template. + """ return jinja2.Markup( bleach.linkify( bleach.clean( @@ -684,8 +693,13 @@ def safe_markup(raw_html: str) -> jinja2.Markup: def safe_text(raw_text: str) -> jinja2.Markup: """ - Process text: treat it as HTML but escape any tags (ie. just escape the - HTML) then linkify it. + Sanitise text (escape any HTML tags), and then linkify any bare URLs. + + Args + raw_text: Unsafe text which might include HTML markup. + + Returns: + A Markup object ready to safely use in a Jinja template. """ return jinja2.Markup( bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False)) diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 7e50341d74..04c2c1482c 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py
@@ -17,7 +17,7 @@ import logging import re from typing import TYPE_CHECKING, Dict, Iterable, Optional -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase from synapse.types import StateMap @@ -63,7 +63,7 @@ async def calculate_room_name( m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) - if m_room_name and m_room_name.content and m_room_name.content["name"]: + if m_room_name and m_room_name.content and m_room_name.content.get("name"): return m_room_name.content["name"] # does it have a canonical alias? @@ -74,15 +74,11 @@ async def calculate_room_name( if ( canon_alias and canon_alias.content - and canon_alias.content["alias"] + and canon_alias.content.get("alias") and _looks_like_an_alias(canon_alias.content["alias"]) ): return canon_alias.content["alias"] - # at this point we're going to need to search the state by all state keys - # for an event type, so rearrange the data structure - room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) - if not fallback_to_members: return None @@ -94,7 +90,7 @@ async def calculate_room_name( if ( my_member_event is not None - and my_member_event.content["membership"] == "invite" + and my_member_event.content.get("membership") == Membership.INVITE ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: inviter_member_event = await store.get_event( @@ -111,6 +107,10 @@ async def calculate_room_name( else: return "Room Invite" + # at this point we're going to need to search the state by all state keys + # for an event type, so rearrange the data structure + room_state_bytype_ids = _state_as_two_level_dict(room_state_ids) + # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: @@ -120,8 +120,8 @@ async def calculate_room_name( all_members = [ ev for ev in member_events.values() - if ev.content["membership"] == "join" - or ev.content["membership"] == "invite" + if ev.content.get("membership") == Membership.JOIN + or ev.content.get("membership") == Membership.INVITE ] # Sort the member events oldest-first so the we name people in the # order the joined (it should at least be deterministic rather than @@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str: def name_from_member_event(member_event: EventBase) -> str: - if ( - member_event.content - and "displayname" in member_event.content - and member_event.content["displayname"] - ): + if member_event.content and member_event.content.get("displayname"): return member_event.content["displayname"] return member_event.state_key diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index eed16dbfb5..3e843c97fe 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity = hs.config.account_validity + self._account_validity_enabled = hs.config.account_validity_enabled # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config @@ -225,7 +225,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) @@ -255,7 +255,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index bfd46a3730..60e6793c8d 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
@@ -112,6 +112,8 @@ CONDITIONAL_REQUIREMENTS = { "redis": ["txredisapi>=1.4.7", "hiredis"], } +CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"] + ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 84e002f934..84afc4c1e4 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py
@@ -98,6 +98,73 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): return 200, {"event_id": event_id, "stream_id": stream_id} +class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): + """Perform a remote knock for the given user on the given room + + Request format: + + POST /_synapse/replication/remote_knock/:room_id/:user_id + + { + "requester": ..., + "remote_room_hosts": [...], + "content": { ... } + } + """ + + NAME = "remote_knock" + PATH_ARGS = ("room_id", "user_id") + + def __init__(self, hs): + super().__init__(hs) + + self.federation_handler = hs.get_federation_handler() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore + requester: Requester, + room_id: str, + user_id: str, + remote_room_hosts: List[str], + content: JsonDict, + ): + """ + Args: + requester: The user making the request, according to the access token. + room_id: The ID of the room to knock on. + user_id: The ID of the knocking user. + remote_room_hosts: Servers to try and send the knock via. + content: The event content to use for the knock event. + """ + return { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "content": content, + } + + async def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str, + ): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + event_content = content["content"] + + requester = Requester.deserialize(self.store, content["requester"]) + + request.requester = requester + + logger.debug("remote_knock: %s on room: %s", user_id, room_id) + + event_id, stream_id = await self.federation_handler.do_knock( + remote_room_hosts, room_id, user_id, event_content + ) + + return 200, {"event_id": event_id, "stream_id": stream_id} + + class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): """Rejects an out-of-band invite we have received from a remote server @@ -166,6 +233,70 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): return 200, {"event_id": event_id, "stream_id": stream_id} +class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): + """Rescinds a local knock made on a remote room + + Request format: + + POST /_synapse/replication/remote_rescind_knock/:event_id + + { + "txn_id": ..., + "requester": ..., + "content": { ... } + } + """ + + NAME = "remote_rescind_knock" + PATH_ARGS = ("knock_event_id",) + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.member_handler = hs.get_room_member_handler() + + @staticmethod + async def _serialize_payload( # type: ignore + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ): + """ + Args: + knock_event_id: The ID of the knock to be rescinded. + txn_id: An optional transaction ID supplied by the client. + requester: The user making the rescind request, according to the access token. + content: The content to include in the rescind event. + """ + return { + "txn_id": txn_id, + "requester": requester.serialize(), + "content": content, + } + + async def _handle_request( # type: ignore + self, request: Request, knock_event_id: str, + ): + content = parse_json_object_from_request(request) + + txn_id = content["txn_id"] + event_content = content["content"] + + requester = Requester.deserialize(self.store, content["requester"]) + + request.requester = requester + + # hopefully we're now on the master, so this won't recurse! + event_id, stream_id = await self.member_handler.remote_rescind_knock( + knock_event_id, txn_id, requester, event_content, + ) + + return 200, {"event_id": event_id, "stream_id": stream_id} + + class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): """Notifies that a user has joined or left the room diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py new file mode 100644
index 0000000000..34fa3ff5b3 --- /dev/null +++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from prometheus_client import Counter + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import json_decoder, json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + +set_counter = Counter( + "synapse_external_cache_set", + "Number of times we set a cache", + labelnames=["cache_name"], +) + +get_counter = Counter( + "synapse_external_cache_get", + "Number of times we get a cache", + labelnames=["cache_name", "hit"], +) + + +logger = logging.getLogger(__name__) + + +class ExternalCache: + """A cache backed by an external Redis. Does nothing if no Redis is + configured. + """ + + def __init__(self, hs: "HomeServer"): + self._redis_connection = hs.get_outbound_redis_connection() + + def _get_redis_key(self, cache_name: str, key: str) -> str: + return "cache_v1:%s:%s" % (cache_name, key) + + def is_enabled(self) -> bool: + """Whether the external cache is used or not. + + It's safe to use the cache when this returns false, the methods will + just no-op, but the function is useful to avoid doing unnecessary work. + """ + return self._redis_connection is not None + + async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: + """Add the key/value to the named cache, with the expiry time given. + """ + + if self._redis_connection is None: + return + + set_counter.labels(cache_name).inc() + + # txredisapi requires the value to be string, bytes or numbers, so we + # encode stuff in JSON. + encoded_value = json_encoder.encode(value) + + logger.debug("Caching %s %s: %r", cache_name, key, encoded_value) + + return await make_deferred_yieldable( + self._redis_connection.set( + self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + ) + ) + + async def get(self, cache_name: str, key: str) -> Optional[Any]: + """Look up a key/value in the named cache. + """ + + if self._redis_connection is None: + return None + + result = await make_deferred_yieldable( + self._redis_connection.get(self._get_redis_key(cache_name, key)) + ) + + logger.debug("Got cache result %s %s: %r", cache_name, key, result) + + get_counter.labels(cache_name, result is not None).inc() + + if not result: + return None + + # For some reason the integers get magically converted back to integers + if isinstance(result, int): + return result + + return json_decoder.decode(result) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..8ea8dcd587 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@ # limitations under the License. import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Dict, @@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import ( TypingStream, ) +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -88,7 +92,7 @@ class ReplicationCommandHandler: back out to connections. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() self._store = hs.get_datastore() @@ -282,13 +286,6 @@ class ReplicationCommandHandler: if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, - lazyConnection, - ) - - logger.info( - "Connecting to redis (host=%r port=%r)", - hs.config.redis_host, - hs.config.redis_port, ) # First let's ensure that we have a ReplicationStreamer started. @@ -299,13 +296,7 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = lazyConnection( - reactor=hs.get_reactor(), - host=hs.config.redis_host, - port=hs.config.redis_port, - password=hs.config.redis.redis_password, - reconnect=True, - ) + outbound_redis_connection = hs.get_outbound_redis_connection() # Now create the factory/connection for the subscription stream. self._factory = RedisDirectTcpReplicationClientFactory( diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Type, cast import txredisapi @@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, run_as_background_process, + wrap_as_background_process, ) from synapse.replication.tcp.commands import ( Command, @@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): immediately after initialisation. Attributes: - handler: The command handler to handle incoming commands. - stream_name: The *redis* stream name to subscribe to and publish from - (not anything to do with Synapse replication streams). - outbound_redis_connection: The connection to redis to use to send + synapse_handler: The command handler to handle incoming commands. + synapse_stream_name: The *redis* stream name to subscribe to and publish + from (not anything to do with Synapse replication streams). + synapse_outbound_redis_connection: The connection to redis to use to send commands. """ - handler = None # type: ReplicationCommandHandler - stream_name = None # type: str - outbound_redis_connection = None # type: txredisapi.RedisProtocol + synapse_handler = None # type: ReplicationCommandHandler + synapse_stream_name = None # type: str + synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # it's important to make sure that we only send the REPLICATE command once we # have successfully subscribed to the stream - otherwise we might miss the # POSITION response sent back by the other end. - logger.info("Sending redis SUBSCRIBE for %s", self.stream_name) - await make_deferred_yieldable(self.subscribe(self.stream_name)) + logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name) + await make_deferred_yieldable(self.subscribe(self.synapse_stream_name)) logger.info( "Successfully subscribed to redis stream, sending REPLICATE command" ) - self.handler.new_connection(self) + self.synapse_handler.new_connection(self) await self._async_send_command(ReplicateCommand()) logger.info("REPLICATE successfully sent") # We send out our positions when there is a new connection in case the # other side missed updates. We do this for Redis connections as the # otherside won't know we've connected and so won't issue a REPLICATE. - self.handler.send_positions_to_connection(self) + self.synapse_handler.send_positions_to_connection(self) def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. @@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): cmd: received command """ - cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) + cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None) if not cmd_func: logger.warning("Unhandled command: %r", cmd) return @@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def connectionLost(self, reason): logger.info("Lost connection to redis") super().connectionLost(reason) - self.handler.lost_connection(self) + self.synapse_handler.lost_connection(self) # mark the logging context as finished self._logging_context.__exit__(None, None, None) @@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() await make_deferred_yieldable( - self.outbound_redis_connection.publish(self.stream_name, encoded_string) + self.synapse_outbound_redis_connection.publish( + self.synapse_stream_name, encoded_string + ) + ) + + +class SynapseRedisFactory(txredisapi.RedisFactory): + """A subclass of RedisFactory that periodically sends pings to ensure that + we detect dead connections. + """ + + def __init__( + self, + hs: "HomeServer", + uuid: str, + dbid: Optional[int], + poolsize: int, + isLazy: bool = False, + handler: Type = txredisapi.ConnectionHandler, + charset: str = "utf-8", + password: Optional[str] = None, + replyTimeout: int = 30, + convertNumbers: Optional[int] = True, + ): + super().__init__( + uuid=uuid, + dbid=dbid, + poolsize=poolsize, + isLazy=isLazy, + handler=handler, + charset=charset, + password=password, + replyTimeout=replyTimeout, + convertNumbers=convertNumbers, ) + hs.get_clock().looping_call(self._send_ping, 30 * 1000) + + @wrap_as_background_process("redis_ping") + async def _send_ping(self): + for connection in self.pool: + try: + await make_deferred_yieldable(connection.ping()) + except Exception: + logger.warning("Failed to send ping to a redis connection") -class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): + +class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """This is a reconnecting factory that connects to redis and immediately subscribes to a stream. @@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol ): - super().__init__() - - # This sets the password on the RedisFactory base class (as - # SubscriberFactory constructor doesn't pass it through). - self.password = hs.config.redis.redis_password + super().__init__( + hs, + uuid="subscriber", + dbid=None, + poolsize=1, + replyTimeout=30, + password=hs.config.redis.redis_password, + ) - self.handler = hs.get_tcp_replication() - self.stream_name = hs.hostname + self.synapse_handler = hs.get_tcp_replication() + self.synapse_stream_name = hs.hostname - self.outbound_redis_connection = outbound_redis_connection + self.synapse_outbound_redis_connection = outbound_redis_connection def buildProtocol(self, addr): - p = super().buildProtocol(addr) # type: RedisSubscriber + p = super().buildProtocol(addr) + p = cast(RedisSubscriber, p) # We do this here rather than add to the constructor of `RedisSubcriber` # as to do so would involve overriding `buildProtocol` entirely, however # the base method does some other things than just instantiating the # protocol. - p.handler = self.handler - p.outbound_redis_connection = self.outbound_redis_connection - p.stream_name = self.stream_name - p.password = self.password + p.synapse_handler = self.synapse_handler + p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection + p.synapse_stream_name = self.synapse_stream_name return p def lazyConnection( - reactor, + hs: "HomeServer", host: str = "localhost", port: int = 6379, dbid: Optional[int] = None, reconnect: bool = True, - charset: str = "utf-8", password: Optional[str] = None, - connectTimeout: Optional[int] = None, - replyTimeout: Optional[int] = None, - convertNumbers: bool = True, + replyTimeout: int = 30, ) -> txredisapi.RedisProtocol: - """Equivalent to `txredisapi.lazyConnection`, except allows specifying a - reactor. + """Creates a connection to Redis that is lazily set up and reconnects if the + connections is lost. """ - isLazy = True - poolsize = 1 - uuid = "%s:%d" % (host, port) - factory = txredisapi.RedisFactory( - uuid, - dbid, - poolsize, - isLazy, - txredisapi.ConnectionHandler, - charset, - password, - replyTimeout, - convertNumbers, + factory = SynapseRedisFactory( + hs, + uuid=uuid, + dbid=dbid, + poolsize=1, + isLazy=True, + handler=txredisapi.ConnectionHandler, + password=password, + replyTimeout=replyTimeout, ) factory.continueTrying = reconnect - for x in range(poolsize): - reactor.connectTCP(host, port, factory, connectTimeout) + + reactor = hs.get_reactor() + reactor.connectTCP(host, port, factory, 30) return factory.handler diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html new file mode 100644
index 0000000000..b751359bdf --- /dev/null +++ b/synapse/res/templates/account_previously_renewed.html
@@ -0,0 +1 @@ +<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html> diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index 894da030af..e8c0f52f05 100644 --- a/synapse/res/templates/account_renewed.html +++ b/synapse/res/templates/account_renewed.html
@@ -1 +1 @@ -<html><body>Your account has been successfully renewed.</body><html> +<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body><html> diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index a75c73a142..0f704b0818 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -12,7 +12,7 @@ <header> <h1>That doesn't look right</h1> <p> - <strong>We were unable to validate your {{ server_name | e }} account</strong> + <strong>We were unable to validate your {{ server_name }} account</strong> via single&nbsp;sign&#8209;on&nbsp;(SSO), because the SSO Identity Provider returned different details than when you logged in. </p> diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index 69b93d65c1..f127804823 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html
@@ -22,7 +22,7 @@ <header> <h1>There was an error</h1> <p> - <strong id="errormsg">{{ error_description | e }}</strong> + <strong id="errormsg">{{ error_description }}</strong> </p> <p> If you are seeing this page after clicking a link sent to you via email, diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 5b38481012..62a640dad2 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -3,22 +3,22 @@ <head> <meta charset="UTF-8"> <link rel="stylesheet" href="/_matrix/static/client/login/style.css"> - <title>{{server_name | e}} Login</title> + <title>{{ server_name }} Login</title> </head> <body> <div id="container"> - <h1 id="title">{{server_name | e}} Login</h1> + <h1 id="title">{{ server_name }} Login</h1> <div class="login_flow"> <p>Choose one of the following identity providers:</p> <form> - <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}"> + <input type="hidden" name="redirectUrl" value="{{ redirect_url }}"> <ul class="radiobuttons"> {% for p in providers %} <li> - <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}"> - <label for="prov{{loop.index}}">{{p.idp_name | e}}</label> + <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}"> + <label for="prov{{ loop.index }}">{{ p.idp_name }}</label> {% if p.idp_icon %} - <img src="{{p.idp_icon | mxc_to_http(32, 32)}}"/> + <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/> {% endif %} </li> {% endfor %} diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
index ce4f573848..d1328a6969 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -12,11 +12,11 @@ <header> {% if new_user %} <h1>Your account is now ready</h1> - <p>You've made your account on {{ server_name | e }}.</p> + <p>You've made your account on {{ server_name }}.</p> {% else %} <h1>Log in</h1> {% endif %} - <p>Continue to confirm you trust <strong>{{ display_url | e }}</strong>.</p> + <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p> </header> <main> {% if user_profile.avatar_url %} @@ -24,13 +24,13 @@ <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> <div class="profile-details"> {% if user_profile.display_name %} - <div class="display-name">{{ user_profile.display_name | e }}</div> + <div class="display-name">{{ user_profile.display_name }}</div> {% endif %} - <div class="user-id">{{ user_id | e }}</div> + <div class="user-id">{{ user_id }}</div> </div> </div> {% endif %} - <a href="{{ redirect_url | e }}" class="primary-button">Continue</a> + <a href="{{ redirect_url }}" class="primary-button">Continue</a> </main> </body> </html> diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 40f5c32db2..ee3a9af569 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -39,6 +39,7 @@ from synapse.rest.client.v2_alpha import ( filter, groups, keys, + knock, notifications, openid, password_policy, @@ -119,8 +120,10 @@ class ClientRestResource(JsonResource): room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) + password_policy.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) password_policy.register_servlets(hs, client_resource) + knock.register_servlets(hs, client_resource) # moving to /_synapse/admin admin.register_servlets_for_client_rest_resource(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6f7dc06503..57e0a10837 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018-2019 New Vector Ltd +# Copyright 2020, 2021 The Matrix.org Foundation C.I.C. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet from synapse.rest.admin.rooms import ( DeleteRoomRestServlet, + ForwardExtremitiesRestServlet, JoinRoomAliasServlet, ListRoomRestServlet, MakeRoomAdminRestServlet, @@ -51,6 +54,7 @@ from synapse.rest.admin.users import ( PushersRestServlet, ResetPasswordRestServlet, SearchUsersRestServlet, + ShadowBanRestServlet, UserAdminServlet, UserMediaRestServlet, UserMembershipRestServlet, @@ -230,6 +234,8 @@ def register_servlets(hs, http_server): EventReportsRestServlet(hs).register(http_server) PushersRestServlet(hs).register(http_server) MakeRoomAdminRestServlet(hs).register(http_server) + ShadowBanRestServlet(hs).register(http_server) + ForwardExtremitiesRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ab7cc9102a..fcbb91f736 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Tuple from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError @@ -23,6 +23,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_integer, parse_json_object_from_request, + parse_list_from_args, parse_string, ) from synapse.http.site import SynapseRequest @@ -323,10 +324,8 @@ class JoinRoomAliasServlet(RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier try: - remote_room_hosts = [ - x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] - except Exception: + remote_room_hosts = parse_list_from_args(request.args, "server_name") + except KeyError: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler @@ -431,7 +430,17 @@ class MakeRoomAdminRestServlet(RestServlet): if not admin_users: raise SynapseError(400, "No local admin user in room") - admin_user_id = admin_users[-1] + admin_user_id = None + + for admin_user in reversed(admin_users): + if room_state.get((EventTypes.Member, admin_user)): + admin_user_id = admin_user + break + + if not admin_user_id: + raise SynapseError( + 400, "No local admin user in room", + ) pl_content = power_levels.content else: @@ -499,3 +508,60 @@ class MakeRoomAdminRestServlet(RestServlet): ) return 200, {} + + +class ForwardExtremitiesRestServlet(RestServlet): + """Allows a server admin to get or clear forward extremities. + + Clearing does not require restarting the server. + + Clear forward extremities: + DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities + + Get forward_extremities: + GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities + """ + + PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.room_member_handler = hs.get_room_member_handler() + self.store = hs.get_datastore() + + async def resolve_room_id(self, room_identifier: str) -> str: + """Resolve to a room ID, if necessary.""" + if RoomID.is_valid(room_identifier): + resolved_room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias) + resolved_room_id = room_id.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + if not resolved_room_id: + raise SynapseError( + 400, "Unknown room ID or room alias %s" % room_identifier + ) + return resolved_room_id + + async def on_DELETE(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + + deleted_count = await self.store.delete_forward_extremities_for_room(room_id) + return 200, {"deleted": deleted_count} + + async def on_GET(self, request, room_identifier): + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + room_id = await self.resolve_room_id(room_identifier) + + extremities = await self.store.get_forward_extremities_for_room(room_id) + return 200, {"count": len(extremities), "results": extremities} diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f39e3d6d5c..68c3c64a0d 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet): The parameter `deactivated` can be used to include deactivated users. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - async def on_GET(self, request): + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) + + if start < 0: + raise SynapseError( + 400, + "Query parameter from must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "Query parameter limit must be a string representing a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + user_id = parse_string(request, "user_id", default=None) name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) @@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet): start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} - if len(users) >= limit: + if (start + limit) < total: ret["next_token"] = str(start + len(users)) return 200, ret @@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet): ) return 200, {"access_token": token} + + +class ShadowBanRestServlet(RestServlet): + """An admin API for shadow-banning a user. + + A shadow-banned users receives successful responses to their client-server + API requests, but the events are not propagated into rooms. + + Shadow-banning a user should be used as a tool of last resort and may lead + to confusing or broken behaviour for the client. + + Example: + + POST /_synapse/admin/v1/users/@test:example.com/shadow_ban + {} + + 200 OK + {} + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.hs.is_mine_id(user_id): + raise SynapseError(400, "Only local users can be shadow-banned") + + await self.store.set_shadow_banned(UserID.from_string(user_id), True) + + return 200, {} diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 23a529f8e3..94bfe2d1b0 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py
@@ -49,9 +49,7 @@ class PresenceStatusRestServlet(RestServlet): raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state( - state, self.clock.time_msec(), include_user_id=False - ) + state = format_user_presence_state(state, self.clock.time_msec()) return 200, state diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 85a66458c5..b5fa1cc464 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@ # limitations under the License. """ This module contains REST servlets to do with profile: /profile/<paths> """ +from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() + self.http_client = hs.get_simple_http_client() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -65,8 +67,27 @@ class ProfileDisplaynameRestServlet(RestServlet): await self.profile_handler.set_displayname(user, requester, new_name, is_admin) + if self.hs.config.shadow_server: + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) + self.shadow_displayname(shadow_user.to_string(), content) + + return 200, {} + + def on_OPTIONS(self, request, user_id): return 200, {} + @defer.inlineCallbacks + def shadow_displayname(self, user_id, body): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.put_json( + "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, + ) + class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) @@ -75,6 +96,7 @@ class ProfileAvatarURLRestServlet(RestServlet): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() + self.http_client = hs.get_simple_http_client() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -113,8 +135,27 @@ class ProfileAvatarURLRestServlet(RestServlet): user, requester, new_avatar_url, is_admin ) + if self.hs.config.shadow_server: + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) + self.shadow_avatar_url(shadow_user.to_string(), content) + + return 200, {} + + def on_OPTIONS(self, request, user_id): return 200, {} + @defer.inlineCallbacks + def shadow_avatar_url(self, user_id, body): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.put_json( + "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, + ) + class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f95627ee61..c8b128583f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -15,10 +15,9 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/<paths> """ - import logging import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from urllib import parse as urlparse from synapse.api.constants import EventTypes, Membership @@ -37,6 +36,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_integer, parse_json_object_from_request, + parse_list_from_args, parse_string, ) from synapse.logging.opentracing import set_tag @@ -283,10 +283,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier try: - remote_room_hosts = [ - x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] - except Exception: + remote_room_hosts = parse_list_from_args(request.args, "server_name") + except KeyError: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler @@ -739,7 +737,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): content["id_server"], requester, txn_id, - content.get("id_access_token"), + new_room=False, + id_access_token=content.get("id_access_token"), ) except ShadowBanError: # Pretend the request succeeded. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..aa170c215f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ # limitations under the License. import logging import random +import re from http import HTTPStatus from typing import TYPE_CHECKING from urllib.parse import urlparse @@ -38,6 +39,7 @@ from synapse.http.servlet import ( ) from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import canonicalise_email, check_3pid_allowed @@ -54,7 +56,7 @@ logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.datastore = hs.get_datastore() @@ -103,6 +105,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -164,6 +168,7 @@ class PasswordRestServlet(RestServlet): self.datastore = self.hs.get_datastore() self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() + self.http_client = hs.get_simple_http_client() @interactive_auth_handler async def on_POST(self, request): @@ -189,24 +194,29 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) - try: - params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "modify your account password", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + # blindly trust ASes without UI-authing them + if requester.app_service: + params = body + else: + try: + ( + params, + session_id, + ) = await self.auth_handler.validate_user_via_ui_auth( + requester, request, body, "modify your account password", ) - raise + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise user_id = requester.user.to_string() else: requester = None @@ -276,8 +286,28 @@ class PasswordRestServlet(RestServlet): user_id, password_hash, logout_devices, requester ) + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + await self.shadow_password(params, shadow_user.to_string()) + + return 200, {} + + def on_OPTIONS(self, _): return 200, {} + async def shadow_password(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") @@ -372,13 +402,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not (await check_3pid_allowed(self.hs, "email", email)): raise SynapseError( 403, - "Your email domain is not authorized on this server", + "Your email is not authorized on this server", Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -430,7 +462,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() self.store = self.hs.get_datastore() @@ -451,13 +483,17 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( 403, "Account phone numbers are not authorized on this server", Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -621,7 +657,8 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = hs.get_datastore() + self.http_client = hs.get_simple_http_client() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) @@ -640,6 +677,29 @@ class ThreepidRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) + # skip validation if this is a shadow 3PID from an AS + if requester.app_service: + # XXX: ASes pass in a validated threepid directly to bypass the IS. + # This makes the API entirely change shape when we have an AS token; + # it really should be an entirely separate API - perhaps + # /account/3pid/replicate or something. + threepid = body.get("threepid") + + await self.auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) + + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + + return 200, {} + threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") if threepid_creds is None: raise SynapseError( @@ -661,12 +721,35 @@ class ThreepidRestServlet(RestServlet): validation_session["address"], validation_session["validated_at"], ) + + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + threepid = { + "medium": validation_session["medium"], + "address": validation_session["address"], + "validated_at": validation_session["validated_at"], + } + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) + async def shadow_3pid(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") @@ -677,6 +760,7 @@ class ThreepidAddRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.http_client = hs.get_simple_http_client() @interactive_auth_handler async def on_POST(self, request): @@ -708,12 +792,33 @@ class ThreepidAddRestServlet(RestServlet): validation_session["address"], validation_session["validated_at"], ) + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + threepid = { + "medium": validation_session["medium"], + "address": validation_session["address"], + "validated_at": validation_session["validated_at"], + } + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) + async def shadow_3pid(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") @@ -783,6 +888,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.http_client = hs.get_simple_http_client() async def on_POST(self, request): if not self.hs.config.enable_3pid_changes: @@ -807,6 +913,12 @@ class ThreepidDeleteRestServlet(RestServlet): logger.exception("Failed to remove threepid") raise SynapseError(500, "Failed to remove threepid") + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + await self.shadow_3pid_delete(body, shadow_user.to_string()) + if ret: id_server_unbind_result = "success" else: @@ -814,6 +926,74 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} + async def shadow_3pid_delete(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + + +class ThreepidLookupRestServlet(RestServlet): + PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")] + + def __init__(self, hs): + super(ThreepidLookupRestServlet, self).__init__() + self.auth = hs.get_auth() + self.identity_handler = hs.get_identity_handler() + + async def on_GET(self, request): + """Proxy a /_matrix/identity/api/v1/lookup request to an identity + server + """ + await self.auth.get_user_by_req(request) + + # Verify query parameters + query_params = request.args + assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"]) + + # Retrieve needed information from query parameters + medium = parse_string(request, "medium") + address = parse_string(request, "address") + id_server = parse_string(request, "id_server") + + # Proxy the request to the identity server. lookup_3pid handles checking + # if the lookup is allowed so we don't need to do it here. + ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address) + + return 200, ret + + +class ThreepidBulkLookupRestServlet(RestServlet): + PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")] + + def __init__(self, hs): + super(ThreepidBulkLookupRestServlet, self).__init__() + self.auth = hs.get_auth() + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity + server + """ + await self.auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["threepids", "id_server"]) + + # Proxy the request to the identity server. lookup_3pid handles checking + # if the lookup is allowed so we don't need to do it here. + ret = await self.identity_handler.proxy_bulk_lookup_3pid( + body["id_server"], body["threepids"] + ) + + return 200, ret + def assert_valid_next_link(hs: "HomeServer", next_link: str): """ @@ -880,4 +1060,6 @@ def register_servlets(hs, http_server): ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) + ThreepidLookupRestServlet(hs).register(http_server) + ThreepidBulkLookupRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 3f28c0bc3e..c9f13e4ac5 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import UserID from ._base import client_patterns @@ -38,6 +39,9 @@ class AccountDataServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.handler = hs.get_account_data_handler() + self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + self._profile_handler = hs.get_profile_handler() async def on_PUT(self, request, user_id, account_data_type): requester = await self.auth.get_user_by_req(request) @@ -46,7 +50,15 @@ class AccountDataServlet(RestServlet): body = parse_json_object_from_request(request) - await self.handler.add_account_data_for_user(user_id, account_data_type, body) + if account_data_type == "im.vector.hide_profile": + user = UserID.from_string(user_id) + hide_profile = body.get("hide_profile") + await self._profile_handler.set_active([user], not hide_profile, True) + + max_id = await self.handler.add_account_data_for_user( + user_id, account_data_type, body + ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) return 200, {} diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index bd7f9ae203..40c5bd4d8c 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -37,24 +37,38 @@ class AccountValidityRenewServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.success_html = hs.config.account_validity.account_renewed_html_content - self.failure_html = hs.config.account_validity.invalid_token_html_content + self.account_renewed_template = ( + hs.config.account_validity_account_renewed_template + ) + self.account_previously_renewed_template = ( + hs.config.account_validity_account_previously_renewed_template + ) + self.invalid_token_template = hs.config.account_validity_invalid_token_template async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = await self.account_activity_handler.renew_account( + ( + token_valid, + token_stale, + expiration_ts, + ) = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) if token_valid: status_code = 200 - response = self.success_html + response = self.account_renewed_template.render(expiration_ts=expiration_ts) + elif token_stale: + status_code = 200 + response = self.account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) else: status_code = 404 - response = self.failure_html + response = self.invalid_token_template.render(expiration_ts=expiration_ts) respond_with_html(request, status_code, response) @@ -72,10 +86,12 @@ class AccountValiditySendMailServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.account_validity = self.hs.config.account_validity + self.account_validity_renew_by_email_enabled = ( + self.hs.config.account_validity_renew_by_email_enabled + ) async def on_POST(self, request): - if not self.account_validity.renew_by_email_enabled: + if not self.account_validity_renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py new file mode 100644
index 0000000000..8439da447e --- /dev/null +++ b/synapse/rest/client/v2_alpha/knock.py
@@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Sorunome +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Optional, Tuple + +from twisted.web.server import Request + +from synapse.api.constants import Membership +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_list_from_args, +) +from synapse.logging.opentracing import set_tag +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import JsonDict, RoomAlias, RoomID + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class KnockRoomAliasServlet(RestServlet): + """ + POST /xyz.amorgan.knock/{roomIdOrAlias} + """ + + PATTERNS = client_patterns( + "/xyz.amorgan.knock/(?P<room_identifier>[^/]*)", releases=() + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.txns = HttpTransactionCache(hs) + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + async def on_POST( + self, request: Request, room_identifier: str, txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + event_content = None + if "reason" in content: + event_content = {"reason": content["reason"]} + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + try: + remote_room_hosts = parse_list_from_args(request.args, "server_name") + except KeyError: + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id_obj.to_string() + else: + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) + + await self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action=Membership.KNOCK, + txn_id=txn_id, + third_party_signed=None, + remote_room_hosts=remote_room_hosts, + content=event_content, + ) + + return 200, {"room_id": room_id} + + def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) + + +def register_servlets(hs, http_server): + KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..5d461efb6a 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -# Copyright 2015 - 2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +18,7 @@ import hmac import logging import random +import re from typing import List, Union import synapse @@ -119,13 +121,15 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not (await check_3pid_allowed(self.hs, "email", body["email"])): raise SynapseError( 403, - "Your email domain is not authorized to register on this server", + "Your email is not authorized to register on this server", Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -198,13 +202,19 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + assert_valid_client_secret(body["client_secret"]) + + if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( 403, "Phone numbers are not authorized to register on this server", Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) @@ -354,15 +364,9 @@ class UsernameAvailabilityRestServlet(RestServlet): 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) - ip = request.getClientIP() - with self.ratelimiter.ratelimit(ip) as wait_deferred: - await wait_deferred - - username = parse_string(request, "username", required=True) - - await self.registration_handler.check_username(username) - - return 200, {"available": True} + # We are not interested in logging in via a username in this deployment. + # Simply allow anything here as it won't be used later. + return 200, {"available": True} class RegisterRestServlet(RestServlet): @@ -412,18 +416,27 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) - # Pull out the provided username and do basic sanity checks early since - # the auth layer will store these in sessions. + # We don't care about usernames for this deployment. In fact, the act + # of checking whether they exist already can leak metadata about + # which users are already registered. + # + # Usernames are already derived via the provided email. + # So, if they're not necessary, just ignore them. + # + # (we do still allow appservices to set them below) desired_username = None - if "username" in body: - if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") - desired_username = body["username"] + + desired_display_name = body.get("display_name") appservice = None if self.auth.has_access_token(request): appservice = self.auth.get_appservice_by_req(request) + # We need to retrieve the password early in order to pass it to + # application service registration + # This is specific to shadow server registration of users via an AS + password = body.pop("password", None) + # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -432,7 +445,7 @@ class RegisterRestServlet(RestServlet): # Set the desired user according to the AS API (which uses the # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. - desired_username = body.get("user", desired_username) + desired_username = body.get("user", body.get("username")) # XXX we should check that desired_username is valid. Currently # we give appservices carte blanche for any insanity in mxids, @@ -445,7 +458,7 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Desired Username is missing or not a string") result = await self._do_appservice_registration( - desired_username, access_token, body + desired_username, password, desired_display_name, access_token, body ) return 200, result @@ -454,16 +467,6 @@ class RegisterRestServlet(RestServlet): if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) - # For regular registration, convert the provided username to lowercase - # before attempting to register it. This should mean that people who try - # to register with upper-case in their usernames don't get a nasty surprise. - # - # Note that we treat usernames case-insensitively in login, so they are - # free to carry on imagining that their username is CrAzYh4cKeR if that - # keeps them happy. - if desired_username is not None: - desired_username = desired_username.lower() - # Check if this account is upgrading from a guest account. guest_access_token = body.get("guest_access_token", None) @@ -472,7 +475,6 @@ class RegisterRestServlet(RestServlet): # Note that we remove the password from the body since the auth layer # will store the body in the session and we don't want a plaintext # password store there. - password = body.pop("password", None) if password is not None: if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") @@ -502,14 +504,6 @@ class RegisterRestServlet(RestServlet): session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) - # Ensure that the username is valid. - if desired_username is not None: - await self.registration_handler.check_username( - desired_username, - guest_access_token=guest_access_token, - assigned_user_id=registered_user_id, - ) - # Check if the user-interactive authentication flows are complete, if # not this will raise a user-interactive auth error. try: @@ -546,7 +540,7 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - if not check_3pid_allowed(self.hs, medium, address): + if not (await check_3pid_allowed(self.hs, medium, address)): raise SynapseError( 403, "Third party identifiers (email/phone numbers)" @@ -554,6 +548,80 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + existingUid = await self.store.get_user_id_by_threepid( + medium, address + ) + + if existingUid is not None: + raise SynapseError( + 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE + ) + + if self.hs.config.register_mxid_from_3pid: + # override the desired_username based on the 3PID if any. + # reset it first to avoid folks picking their own username. + desired_username = None + + # we should have an auth_result at this point if we're going to progress + # to register the user (i.e. we haven't picked up a registered_user_id + # from our session store), in which case get ready and gen the + # desired_username + if auth_result: + if ( + self.hs.config.register_mxid_from_3pid == "email" + and LoginType.EMAIL_IDENTITY in auth_result + ): + address = auth_result[LoginType.EMAIL_IDENTITY]["address"] + desired_username = synapse.types.strip_invalid_mxid_characters( + address.replace("@", "-").lower() + ) + + # find a unique mxid for the account, suffixing numbers + # if needed + while True: + try: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + # if we got this far we passed the check. + break + except SynapseError as e: + if e.errcode == Codes.USER_IN_USE: + m = re.match(r"^(.*?)(\d+)$", desired_username) + if m: + desired_username = m.group(1) + str( + int(m.group(2)) + 1 + ) + else: + desired_username += "1" + else: + # something else went wrong. + break + + if self.hs.config.register_just_use_email_for_display_name: + desired_display_name = address + else: + # Custom mapping between email address and display name + desired_display_name = _map_email_to_displayname(address) + elif ( + self.hs.config.register_mxid_from_3pid == "msisdn" + and LoginType.MSISDN in auth_result + ): + desired_username = auth_result[LoginType.MSISDN]["address"] + else: + raise SynapseError( + 400, "Cannot derive mxid from 3pid; no recognised 3pid" + ) + + if desired_username is not None: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", registered_user_id @@ -568,7 +636,12 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = params.get("username", None) + if not self.hs.config.register_mxid_from_3pid: + desired_username = params.get("username", None) + else: + # we keep the original desired_username derived from the 3pid above + pass + guest_access_token = params.get("guest_access_token", None) if desired_username is not None: @@ -617,6 +690,7 @@ class RegisterRestServlet(RestServlet): localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, + default_display_name=desired_display_name, threepid=threepid, address=client_addr, user_agent_ips=entries, @@ -629,6 +703,14 @@ class RegisterRestServlet(RestServlet): ): await self.store.upsert_monthly_active_user(registered_user_id) + if self.hs.config.shadow_server: + await self.registration_handler.shadow_register( + localpart=desired_username, + display_name=desired_display_name, + auth_result=auth_result, + params=params, + ) + # Remember that the user account has been registered (and the user # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( @@ -652,14 +734,40 @@ class RegisterRestServlet(RestServlet): return 200, return_dict - async def _do_appservice_registration(self, username, as_token, body): + async def _do_appservice_registration( + self, username, password, display_name, as_token, body + ): + # FIXME: appservice_register() is horribly duplicated with register() + # and they should probably just be combined together with a config flag. + + if password: + # Hash the password + # + # In mainline hashing of the password was moved further on in the registration + # flow, but we need it here for the AS use-case of shadow servers + password = await self.auth_handler.hash(password) user_id = await self.registration_handler.appservice_register( - username, as_token + username, as_token, password, display_name ) - return await self._create_registration_details( + result = await self._create_registration_details( user_id, body, is_appservice_ghost=True, ) + auth_result = body.get("auth_result") + if auth_result and LoginType.EMAIL_IDENTITY in auth_result: + threepid = auth_result[LoginType.EMAIL_IDENTITY] + await self.registration_handler.register_email_threepid( + user_id, threepid, result["access_token"] + ) + + if auth_result and LoginType.MSISDN in auth_result: + threepid = auth_result[LoginType.MSISDN] + await self.registration_handler.register_msisdn_threepid( + user_id, threepid, result["access_token"] + ) + + return result + async def _create_registration_details( self, user_id, params, is_appservice_ghost=False ): @@ -715,6 +823,60 @@ class RegisterRestServlet(RestServlet): ) +def cap(name): + """Capitalise parts of a name containing different words, including those + separated by hyphens. + For example, 'John-Doe' + + Args: + name (str): The name to parse + """ + if not name: + return name + + # Split the name by whitespace then hyphens, capitalizing each part then + # joining it back together. + capatilized_name = " ".join( + "-".join(part.capitalize() for part in space_part.split("-")) + for space_part in name.split() + ) + return capatilized_name + + +def _map_email_to_displayname(address): + """Custom mapping from an email address to a user displayname + + Args: + address (str): The email address to process + Returns: + str: The new displayname + """ + # Split the part before and after the @ in the email. + # Replace all . with spaces in the first part + parts = address.replace(".", " ").split("@") + + # Figure out which org this email address belongs to + org_parts = parts[1].split(" ") + + # If this is a ...matrix.org email, mark them as an Admin + if org_parts[-2] == "matrix" and org_parts[-1] == "org": + org = "Tchap Admin" + + # Is this is a ...gouv.fr address, set the org to whatever is before + # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their + # org as "gouv" + elif org_parts[-2] == "gouv" and org_parts[-1] == "fr": + org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2] + + # Otherwise, mark their org as the email's second-level domain name + else: + org = org_parts[-2] + + desired_display_name = cap(parts[0]) + " [" + cap(org) + "]" + + return desired_display_name + + def _calculate_registration_flows( # technically `config` has to provide *all* of these interfaces, not just one config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8e52e4cca4..582c999abd 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py
@@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging +from typing import Any, Callable, Dict, List -from synapse.api.constants import PresenceState +from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( @@ -24,7 +24,7 @@ from synapse.events.utils import ( format_event_raw, ) from synapse.handlers.presence import format_user_presence_state -from synapse.handlers.sync import SyncConfig +from synapse.handlers.sync import KnockedSyncResult, SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken from synapse.util import json_decoder @@ -213,6 +213,10 @@ class SyncRestServlet(RestServlet): sync_result.invited, time_now, access_token_id, event_formatter ) + knocked = await self.encode_knocked( + sync_result.knocked, time_now, access_token_id, event_formatter + ) + archived = await self.encode_archived( sync_result.archived, time_now, @@ -230,11 +234,16 @@ class SyncRestServlet(RestServlet): "left": list(sync_result.device_lists.left), }, "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now), - "rooms": {"join": joined, "invite": invited, "leave": archived}, + "rooms": { + Membership.JOIN: joined, + Membership.INVITE: invited, + Membership.KNOCK: knocked, + Membership.LEAVE: archived, + }, "groups": { - "join": sync_result.groups.join, - "invite": sync_result.groups.invite, - "leave": sync_result.groups.leave, + Membership.JOIN: sync_result.groups.join, + Membership.INVITE: sync_result.groups.invite, + Membership.LEAVE: sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, @@ -296,7 +305,7 @@ class SyncRestServlet(RestServlet): Args: rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of - sync results for rooms this user is joined to + sync results for rooms this user is invited to time_now(int): current time - used as a baseline for age calculations token_id(int): ID of the user's auth token - used for namespacing @@ -315,7 +324,7 @@ class SyncRestServlet(RestServlet): time_now, token_id=token_id, event_format=event_formatter, - is_invite=True, + include_stripped_room_state=True, ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned @@ -325,6 +334,60 @@ class SyncRestServlet(RestServlet): return invited + async def encode_knocked( + self, + rooms: List[KnockedSyncResult], + time_now: int, + token_id: int, + event_formatter: Callable[[Dict], Dict], + ) -> Dict[str, Dict[str, Any]]: + """ + Encode the rooms we've knocked on in a sync result. + + Args: + rooms: list of sync results for rooms this user is knocking on + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs + event_formatter: function to convert from federation format to client format + + Returns: + The list of rooms the user has knocked on, in our response format. + """ + knocked = {} + for room in rooms: + knock = await self._event_serializer.serialize_event( + room.knock, + time_now, + token_id=token_id, + event_format=event_formatter, + include_stripped_room_state=True, + ) + + # Extract the `unsigned` key from the knock event. + # This is where we (cheekily) store the knock state events + unsigned = knock.setdefault("unsigned", {}) + + # Duplicate the dictionary in order to avoid modifying the original + unsigned = dict(unsigned) + + # Extract the stripped room state from the unsigned dict + # This is for clients to get a little bit of information about + # the room they've knocked on, without revealing any sensitive information + knocked_state = list(unsigned.pop("knock_room_state", [])) + + # Append the actual knock membership event itself as well. This provides + # the client with: + # + # * A knock state event that they can use for easier internal tracking + # * The rough timestamp of when the knock occurred contained within the event + knocked_state.append(knock) + + # Build the `knock_state` dictionary, which will contain the state of the + # room that the client has knocked on + knocked[room.room_id] = {"knock_state": {"events": knocked_state}} + + return knocked + async def encode_archived( self, rooms, time_now, token_id, event_fields, event_formatter ): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index ad598cefe0..eeddfa31f8 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@ # limitations under the License. import logging +from typing import Dict -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from signedjson.sign import sign_json + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.types import UserID from ._base import client_patterns @@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() async def on_POST(self, request): """Searches for users in directory @@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet): body = parse_json_object_from_request(request) + if self.hs.config.user_directory_defer_to_id_server: + signed_body = sign_json( + body, self.hs.hostname, self.hs.config.signing_key[0] + ) + url = "%s/_matrix/identity/api/v1/user_directory/search" % ( + self.hs.config.user_directory_defer_to_id_server, + ) + resp = await self.http_client.post_json_get_json(url, signed_body) + return 200, resp + limit = body.get("limit", 10) limit = min(limit, 50) @@ -76,5 +95,124 @@ class UserDirectorySearchRestServlet(RestServlet): return 200, results +class SingleUserInfoServlet(RestServlet): + """ + Deprecated and replaced by `/users/info` + + GET /user/{user_id}/info HTTP/1.1 + """ + + PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$") + + def __init__(self, hs): + super(SingleUserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + registry = hs.get_federation_registry() + + if not registry.query_handlers.get("user_info"): + registry.register_query_handler("user_info", self._on_federation_query) + + async def on_GET(self, request, user_id): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + # Attempt to make a federation request to the server that owns this user + args = {"user_id": user_id} + res = await self.transport_layer.make_query( + user.domain, "user_info", args, retry_on_dns_fail=True + ) + return 200, res + + user_id_to_info = await self.store.get_info_for_users([user_id]) + return 200, user_id_to_info[user_id] + + async def _on_federation_query(self, args): + """Called when a request for user information appears over federation + + Args: + args (dict): Dictionary of query arguments provided by the request + + Returns: + Deferred[dict]: Deactivation and expiration information for a given user + """ + user_id = args.get("user_id") + if not user_id: + raise SynapseError(400, "user_id not provided") + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + raise SynapseError(400, "User is not hosted on this homeserver") + + user_ids_to_info_dict = await self.store.get_info_for_users([user_id]) + return user_ids_to_info_dict[user_id] + + +class UserInfoServlet(RestServlet): + """Bulk version of `/user/{user_id}/info` endpoint + + GET /users/info HTTP/1.1 + + Returns a dictionary of user_id to info dictionary. Supports remote users + """ + + PATTERNS = client_patterns("/users/info$", unstable=True, releases=()) + + def __init__(self, hs): + super(UserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + + async def on_POST(self, request): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + # Extract the user_ids from the request + body = parse_json_object_from_request(request) + assert_params_in_dict(body, required=["user_ids"]) + + user_ids = body["user_ids"] + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + # Separate local and remote users + local_user_ids = set() + remote_server_to_user_ids = {} # type: Dict[str, set] + for user_id in user_ids: + user = UserID.from_string(user_id) + + if self.hs.is_mine(user): + local_user_ids.add(user_id) + else: + remote_server_to_user_ids.setdefault(user.domain, set()) + remote_server_to_user_ids[user.domain].add(user_id) + + # Retrieve info of all local users + user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids) + + # Request info of each remote user from their remote homeserver + for server_name, user_id_set in remote_server_to_user_ids.items(): + # Make a request to the given server about their own users + res = await self.transport_layer.get_info_of_users( + server_name, list(user_id_set) + ) + + user_id_to_info_dict.update(res) + + return 200, user_id_to_info_dict + + def register_servlets(hs, http_server): UserDirectorySearchRestServlet(hs).register(http_server) + SingleUserInfoServlet(hs).register(http_server) + UserInfoServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index d24a199318..c9b9e7f5ff 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -72,9 +72,12 @@ class VersionsRestServlet(RestServlet): # MSC2326. "org.matrix.label_based_filtering": True, # Implements support for cross signing as described in MSC1756 - "org.matrix.e2e_cross_signing": True, + # "org.matrix.e2e_cross_signing": True, # Implements additional endpoints as described in MSC2432 "org.matrix.msc2432": True, + # Tchap does not currently assume this rule for r0.5.0 + # XXX: Remove this when it does + "m.lazy_load_members": True, # Implements additional endpoints as described in MSC2666 "uk.half-shot.msc2666": True, # Whether new rooms will be set to encrypted or not (based on presets). diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 31a41e4a27..f71a03a12d 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py
@@ -300,6 +300,7 @@ class FileInfo: thumbnail_height (int) thumbnail_method (str) thumbnail_type (str): Content type of thumbnail, e.g. image/png + thumbnail_length (int): The size of the media file, in bytes. """ def __init__( @@ -312,6 +313,7 @@ class FileInfo: thumbnail_height=None, thumbnail_method=None, thumbnail_type=None, + thumbnail_length=None, ): self.server_name = server_name self.file_id = file_id @@ -321,6 +323,7 @@ class FileInfo: self.thumbnail_height = thumbnail_height self.thumbnail_method = thumbnail_method self.thumbnail_type = thumbnail_type + self.thumbnail_length = thumbnail_length def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a632099167..5ac307a62d 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -146,8 +146,7 @@ class PreviewUrlResource(DirectServeJsonResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, - http_proxy=os.getenvb(b"http_proxy"), - https_proxy=os.getenvb(b"HTTPS_PROXY"), + use_proxy=True, ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path @@ -386,7 +385,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Check whether the URL should be downloaded as oEmbed content instead. - Params: + Args: url: The URL to check. Returns: @@ -403,7 +402,7 @@ class PreviewUrlResource(DirectServeJsonResource): """ Request content from an oEmbed endpoint. - Params: + Args: endpoint: The oEmbed API endpoint. url: The URL to pass to the API. @@ -692,27 +691,51 @@ class PreviewUrlResource(DirectServeJsonResource): def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None ) -> Dict[str, Optional[str]]: + """ + Calculate metadata for an HTML document. + + This uses lxml to parse the HTML document into the OG response. If errors + occur during processing of the document, an empty response is returned. + + Args: + body: The HTML document, as bytes. + media_url: The URI used to download the body. + request_encoding: The character encoding of the body, as a string. + + Returns: + The OG response as a dictionary. + """ # If there's no body, nothing useful is going to be found. if not body: return {} from lxml import etree + # Create an HTML parser. If this fails, log and return no metadata. try: parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body, parser) - og = _calc_og(tree, media_uri) + except LookupError: + # blindly consider the encoding as utf-8. + parser = etree.HTMLParser(recover=True, encoding="utf-8") + except Exception as e: + logger.warning("Unable to create HTML parser: %s" % (e,)) + return {} + + def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: + # Attempt to parse the body. If this fails, log and return no metadata. + tree = etree.fromstring(body_attempt, parser) + return _calc_og(tree, media_uri) + + # Attempt to parse the body. If this fails, log and return no metadata. + try: + return _attempt_calc_og(body) except UnicodeDecodeError: # blindly try decoding the body as utf-8, which seems to fix # the charset mismatches on https://google.com - parser = etree.HTMLParser(recover=True, encoding=request_encoding) - tree = etree.fromstring(body.decode("utf-8", "ignore"), parser) - og = _calc_og(tree, media_uri) - - return og + return _attempt_calc_og(body.decode("utf-8", "ignore")) -def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]: +def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]: # suck our tree into lxml and define our OG response. # if we see any image URLs in the OG response, then spider them diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d6880f2e6e..d653a58be9 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,7 +16,7 @@ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Optional from twisted.web.http import Request @@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource): return thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) - - if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos - ) - - file_info = FileInfo( - server_name=None, - file_id=media_id, - url_cache=media_info["url_cache"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] - - responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) - else: - logger.info("Couldn't find any generated thumbnails") - respond_404(request) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_id, + url_cache=media_info["url_cache"], + server_name=None, + ) async def _select_or_generate_local_thumbnail( self, @@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource): thumbnail_infos = await self.store.get_remote_media_thumbnails( server_name, media_id ) + await self._select_and_respond_with_thumbnail( + request, + width, + height, + method, + m_type, + thumbnail_infos, + media_info["filesystem_id"], + url_cache=None, + server_name=server_name, + ) + async def _select_and_respond_with_thumbnail( + self, + request: Request, + desired_width: int, + desired_height: int, + desired_method: str, + desired_type: str, + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str] = None, + server_name: Optional[str] = None, + ) -> None: + """ + Respond to a request with an appropriate thumbnail from the previously generated thumbnails. + + Args: + request: The incoming request. + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + """ if thumbnail_infos: - thumbnail_info = self._select_thumbnail( - width, height, method, m_type, thumbnail_infos + file_info = self._select_thumbnail( + desired_width, + desired_height, + desired_method, + desired_type, + thumbnail_infos, + file_id, + url_cache, + server_name, ) - file_info = FileInfo( - server_name=server_name, - file_id=media_info["filesystem_id"], - thumbnail=True, - thumbnail_width=thumbnail_info["thumbnail_width"], - thumbnail_height=thumbnail_info["thumbnail_height"], - thumbnail_type=thumbnail_info["thumbnail_type"], - thumbnail_method=thumbnail_info["thumbnail_method"], - ) - - t_type = file_info.thumbnail_type - t_length = thumbnail_info["thumbnail_length"] + if not file_info: + logger.info("Couldn't find a thumbnail matching the desired inputs") + respond_404(request) + return responder = await self.media_storage.fetch_media(file_info) - await respond_with_responder(request, responder, t_type, t_length) + await respond_with_responder( + request, responder, file_info.thumbnail_type, file_info.thumbnail_length + ) else: logger.info("Failed to find any generated thumbnails") respond_404(request) @@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource): desired_height: int, desired_method: str, desired_type: str, - thumbnail_infos, - ) -> dict: + thumbnail_infos: List[Dict[str, Any]], + file_id: str, + url_cache: Optional[str], + server_name: Optional[str], + ) -> Optional[FileInfo]: + """ + Choose an appropriate thumbnail from the previously generated thumbnails. + + Args: + desired_width: The desired width, the returned thumbnail may be larger than this. + desired_height: The desired height, the returned thumbnail may be larger than this. + desired_method: The desired method used to generate the thumbnail. + desired_type: The desired content-type of the thumbnail. + thumbnail_infos: A list of dictionaries of candidate thumbnails. + file_id: The ID of the media that a thumbnail is being requested for. + url_cache: The URL cache value. + server_name: The server name, if this is a remote thumbnail. + + Returns: + The thumbnail which best matches the desired parameters. + """ + desired_method = desired_method.lower() + + # The chosen thumbnail. + thumbnail_info = None + d_w = desired_width d_h = desired_height - if desired_method.lower() == "crop": + if desired_method == "crop": + # Thumbnails that match equal or larger sizes of desired width/height. crop_info_list = [] + # Other thumbnails. crop_info_list2 = [] for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "crop": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] - if t_method == "crop": - aspect_quality = abs(d_w * t_h - d_h * t_w) - min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 - size_quality = abs((d_w - t_w) * (d_h - t_h)) - type_quality = desired_type != info["thumbnail_type"] - length_quality = info["thumbnail_length"] - if t_w >= d_w or t_h >= d_h: - crop_info_list.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + aspect_quality = abs(d_w * t_h - d_h * t_w) + min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 + size_quality = abs((d_w - t_w) * (d_h - t_h)) + type_quality = desired_type != info["thumbnail_type"] + length_quality = info["thumbnail_length"] + if t_w >= d_w or t_h >= d_h: + crop_info_list.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) - else: - crop_info_list2.append( - ( - aspect_quality, - min_quality, - size_quality, - type_quality, - length_quality, - info, - ) + ) + else: + crop_info_list2.append( + ( + aspect_quality, + min_quality, + size_quality, + type_quality, + length_quality, + info, ) + ) if crop_info_list: - return min(crop_info_list)[-1] - else: - return min(crop_info_list2)[-1] - else: + thumbnail_info = min(crop_info_list)[-1] + elif crop_info_list2: + thumbnail_info = min(crop_info_list2)[-1] + elif desired_method == "scale": + # Thumbnails that match equal or larger sizes of desired width/height. info_list = [] + # Other thumbnails. info_list2 = [] + for info in thumbnail_infos: + # Skip thumbnails generated with different methods. + if info["thumbnail_method"] != "scale": + continue + t_w = info["thumbnail_width"] t_h = info["thumbnail_height"] - t_method = info["thumbnail_method"] size_quality = abs((d_w - t_w) * (d_h - t_h)) type_quality = desired_type != info["thumbnail_type"] length_quality = info["thumbnail_length"] - if t_method == "scale" and (t_w >= d_w or t_h >= d_h): + if t_w >= d_w or t_h >= d_h: info_list.append((size_quality, type_quality, length_quality, info)) - elif t_method == "scale": + else: info_list2.append( (size_quality, type_quality, length_quality, info) ) if info_list: - return min(info_list)[-1] - else: - return min(info_list2)[-1] + thumbnail_info = min(info_list)[-1] + elif info_list2: + thumbnail_info = min(info_list2)[-1] + + if thumbnail_info: + return FileInfo( + file_id=file_id, + url_cache=url_cache, + server_name=server_name, + thumbnail=True, + thumbnail_width=thumbnail_info["thumbnail_width"], + thumbnail_height=thumbnail_info["thumbnail_height"], + thumbnail_type=thumbnail_info["thumbnail_type"], + thumbnail_method=thumbnail_info["thumbnail_method"], + thumbnail_length=thumbnail_info["thumbnail_length"], + ) + + # No matching thumbnail was found. + return None diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py new file mode 100644
index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py new file mode 100644
index 0000000000..6f2a1931c5 --- /dev/null +++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.config._base import ConfigError + +logger = logging.getLogger(__name__) + + +class DomainRuleChecker(object): + """ + A re-implementation of the SpamChecker that prevents users in one domain from + inviting users in other domains to rooms, based on a configuration. + + Takes a config in the format: + + spam_checker: + module: "rulecheck.DomainRuleChecker" + config: + domain_mapping: + "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ] + "other_inviter_domain": [ "invitee_domain_permitted" ] + default: False + + # Only let local users join rooms if they were explicitly invited. + can_only_join_rooms_with_invite: false + + # Only let local users create rooms if they are inviting only one + # other user, and that user matches the rules above. + can_only_create_one_to_one_rooms: false + + # Only let local users invite during room creation, regardless of the + # domain mapping rules above. + can_only_invite_during_room_creation: false + + # Prevent local users from inviting users from certain domains to + # rooms published in the room directory. + domains_prevented_from_being_invited_to_published_rooms: [] + + # Allow third party invites + can_invite_by_third_party_id: true + + Don't forget to consider if you can invite users from your own domain. + """ + + def __init__(self, config): + self.domain_mapping = config["domain_mapping"] or {} + self.default = config["default"] + + self.can_only_join_rooms_with_invite = config.get( + "can_only_join_rooms_with_invite", False + ) + self.can_only_create_one_to_one_rooms = config.get( + "can_only_create_one_to_one_rooms", False + ) + self.can_only_invite_during_room_creation = config.get( + "can_only_invite_during_room_creation", False + ) + self.can_invite_by_third_party_id = config.get( + "can_invite_by_third_party_id", True + ) + self.domains_prevented_from_being_invited_to_published_rooms = config.get( + "domains_prevented_from_being_invited_to_published_rooms", [] + ) + + def check_event_for_spam(self, event): + """Implements synapse.events.SpamChecker.check_event_for_spam + """ + return False + + def user_may_invite( + self, + inviter_userid, + invitee_userid, + third_party_invite, + room_id, + new_room, + published_room=False, + ): + """Implements synapse.events.SpamChecker.user_may_invite + """ + if self.can_only_invite_during_room_creation and not new_room: + return False + + if not self.can_invite_by_third_party_id and third_party_invite: + return False + + # This is a third party invite (without a bound mxid), so unless we have + # banned all third party invites (above) we allow it. + if not invitee_userid: + return True + + inviter_domain = self._get_domain_from_id(inviter_userid) + invitee_domain = self._get_domain_from_id(invitee_userid) + + if inviter_domain not in self.domain_mapping: + return self.default + + if ( + published_room + and invitee_domain + in self.domains_prevented_from_being_invited_to_published_rooms + ): + return False + + return invitee_domain in self.domain_mapping[inviter_domain] + + def user_may_create_room( + self, userid, invite_list, third_party_invite_list, cloning + ): + """Implements synapse.events.SpamChecker.user_may_create_room + """ + + if cloning: + return True + + if not self.can_invite_by_third_party_id and third_party_invite_list: + return False + + number_of_invites = len(invite_list) + len(third_party_invite_list) + + if self.can_only_create_one_to_one_rooms and number_of_invites != 1: + return False + + return True + + def user_may_create_room_alias(self, userid, room_alias): + """Implements synapse.events.SpamChecker.user_may_create_room_alias + """ + return True + + def user_may_publish_room(self, userid, room_id): + """Implements synapse.events.SpamChecker.user_may_publish_room + """ + return True + + def user_may_join_room(self, userid, room_id, is_invited): + """Implements synapse.events.SpamChecker.user_may_join_room + """ + if self.can_only_join_rooms_with_invite and not is_invited: + return False + + return True + + @staticmethod + def parse_config(config): + """Implements synapse.events.SpamChecker.parse_config + """ + if "default" in config: + return config + else: + raise ConfigError("No default set for spam_config DomainRuleChecker") + + @staticmethod + def _get_domain_from_id(mxid): + """Parses a string and returns the domain part of the mxid. + + Args: + mxid (str): a valid mxid + + Returns: + str: the domain part of the mxid + + """ + idx = mxid.find(":") + if idx == -1: + raise Exception("Invalid ID: %r" % (mxid,)) + return mxid[idx + 1 :] diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..6ffb7e0fd9 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -24,7 +24,6 @@ import abc import functools import logging -import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast import twisted.internet.base @@ -103,6 +102,7 @@ from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream @@ -128,6 +128,8 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.handlers.oidc_handler import OidcHandler from synapse.handlers.saml_handler import SamlHandler @@ -357,11 +359,7 @@ class HomeServer(metaclass=abc.ABCMeta): """ An HTTP client that uses configured HTTP(S) proxies. """ - return SimpleHttpClient( - self, - http_proxy=os.getenvb(b"http_proxy"), - https_proxy=os.getenvb(b"HTTPS_PROXY"), - ) + return SimpleHttpClient(self, use_proxy=True) @cache_in_self def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: @@ -373,8 +371,7 @@ class HomeServer(metaclass=abc.ABCMeta): self, ip_whitelist=self.config.ip_range_whitelist, ip_blacklist=self.config.ip_range_blacklist, - http_proxy=os.getenvb(b"http_proxy"), - https_proxy=os.getenvb(b"HTTPS_PROXY"), + use_proxy=True, ) @cache_in_self @@ -716,6 +713,33 @@ class HomeServer(metaclass=abc.ABCMeta): def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) + @cache_in_self + def get_external_cache(self) -> ExternalCache: + return ExternalCache(self) + + @cache_in_self + def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: + if not self.config.redis.redis_enabled: + return None + + # We only want to import redis module if we're using it, as we have + # `txredisapi` as an optional dependency. + from synapse.replication.tcp.redis import lazyConnection + + logger.info( + "Connecting to redis (host=%r port=%r) for external cache", + self.config.redis_host, + self.config.redis_port, + ) + + return lazyConnection( + hs=self, + host=self.config.redis_host, + port=self.config.redis_port, + password=self.config.redis.redis_password, + reconnect=True, + ) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..3bd9ff8ca0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None + entry = None else: # otherwise, we'll need to resolve the state across the prev_events. @@ -340,9 +341,13 @@ class StateHandler: current_state_ids=state_ids_before_event, ) - # XXX: can we update the state cache entry for the new state group? or - # could we set a flag on resolve_state_groups_for_events to tell it to - # always make a state group? + # Assign the new state group to the cached state entry. + # + # Note that this can race in that we could generate multiple state + # groups for the same state entry, but that is just inefficient + # rather than dangerous. + if entry and entry.state_group is None: + entry.state_group = state_group_before_event # # now if it's not a state event, we're done diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index c7220bc778..d2ba4bd2fc 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -262,6 +262,12 @@ class LoggingTransaction: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: + """Similar to `executemany`, except `txn.rowcount` will not be correct + afterwards. + + More efficient than `executemany` on PostgreSQL + """ + if isinstance(self.database_engine, PostgresEngine): from psycopg2.extras import execute_batch # type: ignore diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..5d0845588c 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore from .event_federation import EventFederationStore from .event_push_actions import EventPushActionsStore from .events_bg_updates import EventsBackgroundUpdatesStore +from .events_forward_extremities import EventForwardExtremitiesStore from .filtering import FilteringStore from .group_server import GroupServerStore from .keys import KeyStore @@ -118,6 +119,7 @@ class DataStore( UIAuthStore, CacheInvalidationWorkerStore, ServerMetricsStore, + EventForwardExtremitiesStore, ): def __init__(self, database: DatabasePool, db_conn, hs): self.hs = hs diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore): DELETE FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? """ - txn.executemany(sql, ((row[0], row[1]) for row in rows)) + txn.execute_batch(sql, ((row[0], row[1]) for row in rows)) logger.info("Pruned %d device list outbound pokes", count) @@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # Delete older entries in the table, as we really only care about # when the latest change happened. - txn.executemany( + txn.execute_batch( """ DELETE FROM device_lists_stream WHERE user_id = ? AND device_id = ? AND stream_id < ? diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c128889bf9..309f1e865b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, dict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] - ) -> Dict[str, Dict[str, Dict[str, bytes]]]: + ) -> Dict[str, Dict[str, Dict[str, str]]]: """Take a list of one time keys out of the database. Args: diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..438383abe1 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): VALUES (?, ?, ?, ?, ?, ?) """ - txn.executemany( + txn.execute_batch( sql, ( _gen_entry(user_id, actions) @@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ], ) - txn.executemany( + txn.execute_batch( """ UPDATE event_push_summary SET notif_count = ?, unread_count = ?, stream_ordering = ? diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5db7d7aaa8..ccda9f1caa 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -473,8 +473,9 @@ class PersistEventsStore: txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, ) - @staticmethod + @classmethod def _add_chain_cover_index( + cls, txn, db_pool: DatabasePool, event_to_room_id: Dict[str, str], @@ -614,60 +615,17 @@ class PersistEventsStore: if not events_to_calc_chain_id_for: return - # We now calculate the chain IDs/sequence numbers for the events. We - # do this by looking at the chain ID and sequence number of any auth - # event with the same type/state_key and incrementing the sequence - # number by one. If there was no match or the chain ID/sequence - # number is already taken we generate a new chain. - # - # We need to do this in a topologically sorted order as we want to - # generate chain IDs/sequence numbers of an event's auth events - # before the event itself. - chains_tuples_allocated = set() # type: Set[Tuple[int, int]] - new_chain_tuples = {} # type: Dict[str, Tuple[int, int]] - for event_id in sorted_topologically( - events_to_calc_chain_id_for, event_to_auth_chain - ): - existing_chain_id = None - for auth_id in event_to_auth_chain.get(event_id, []): - if event_to_types.get(event_id) == event_to_types.get(auth_id): - existing_chain_id = chain_map[auth_id] - break - - new_chain_tuple = None - if existing_chain_id: - # We found a chain ID/sequence number candidate, check its - # not already taken. - proposed_new_id = existing_chain_id[0] - proposed_new_seq = existing_chain_id[1] + 1 - if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated: - already_allocated = db_pool.simple_select_one_onecol_txn( - txn, - table="event_auth_chains", - keyvalues={ - "chain_id": proposed_new_id, - "sequence_number": proposed_new_seq, - }, - retcol="event_id", - allow_none=True, - ) - if already_allocated: - # Mark it as already allocated so we don't need to hit - # the DB again. - chains_tuples_allocated.add((proposed_new_id, proposed_new_seq)) - else: - new_chain_tuple = ( - proposed_new_id, - proposed_new_seq, - ) - - if not new_chain_tuple: - new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1) - - chains_tuples_allocated.add(new_chain_tuple) - - chain_map[event_id] = new_chain_tuple - new_chain_tuples[event_id] = new_chain_tuple + # Allocate chain ID/sequence numbers to each new event. + new_chain_tuples = cls._allocate_chain_ids( + txn, + db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, + events_to_calc_chain_id_for, + chain_map, + ) + chain_map.update(new_chain_tuples) db_pool.simple_insert_many_txn( txn, @@ -794,6 +752,137 @@ class PersistEventsStore: ], ) + @staticmethod + def _allocate_chain_ids( + txn, + db_pool: DatabasePool, + event_to_room_id: Dict[str, str], + event_to_types: Dict[str, Tuple[str, str]], + event_to_auth_chain: Dict[str, List[str]], + events_to_calc_chain_id_for: Set[str], + chain_map: Dict[str, Tuple[int, int]], + ) -> Dict[str, Tuple[int, int]]: + """Allocates, but does not persist, chain ID/sequence numbers for the + events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index + for info on args) + """ + + # We now calculate the chain IDs/sequence numbers for the events. We do + # this by looking at the chain ID and sequence number of any auth event + # with the same type/state_key and incrementing the sequence number by + # one. If there was no match or the chain ID/sequence number is already + # taken we generate a new chain. + # + # We try to reduce the number of times that we hit the database by + # batching up calls, to make this more efficient when persisting large + # numbers of state events (e.g. during joins). + # + # We do this by: + # 1. Calculating for each event which auth event will be used to + # inherit the chain ID, i.e. converting the auth chain graph to a + # tree that we can allocate chains on. We also keep track of which + # existing chain IDs have been referenced. + # 2. Fetching the max allocated sequence number for each referenced + # existing chain ID, generating a map from chain ID to the max + # allocated sequence number. + # 3. Iterating over the tree and allocating a chain ID/seq no. to the + # new event, by incrementing the sequence number from the + # referenced event's chain ID/seq no. and checking that the + # incremented sequence number hasn't already been allocated (by + # looking in the map generated in the previous step). We generate a + # new chain if the sequence number has already been allocated. + # + + existing_chains = set() # type: Set[int] + tree = [] # type: List[Tuple[str, Optional[str]]] + + # We need to do this in a topologically sorted order as we want to + # generate chain IDs/sequence numbers of an event's auth events before + # the event itself. + for event_id in sorted_topologically( + events_to_calc_chain_id_for, event_to_auth_chain + ): + for auth_id in event_to_auth_chain.get(event_id, []): + if event_to_types.get(event_id) == event_to_types.get(auth_id): + existing_chain_id = chain_map.get(auth_id) + if existing_chain_id: + existing_chains.add(existing_chain_id[0]) + + tree.append((event_id, auth_id)) + break + else: + tree.append((event_id, None)) + + # Fetch the current max sequence number for each existing referenced chain. + sql = """ + SELECT chain_id, MAX(sequence_number) FROM event_auth_chains + WHERE %s + GROUP BY chain_id + """ + clause, args = make_in_list_sql_clause( + db_pool.engine, "chain_id", existing_chains + ) + txn.execute(sql % (clause,), args) + + chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int] + + # Allocate the new events chain ID/sequence numbers. + # + # To reduce the number of calls to the database we don't allocate a + # chain ID number in the loop, instead we use a temporary `object()` for + # each new chain ID. Once we've done the loop we generate the necessary + # number of new chain IDs in one call, replacing all temporary + # objects with real allocated chain IDs. + + unallocated_chain_ids = set() # type: Set[object] + new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]] + for event_id, auth_event_id in tree: + # If we reference an auth_event_id we fetch the allocated chain ID, + # either from the existing `chain_map` or the newly generated + # `new_chain_tuples` map. + existing_chain_id = None + if auth_event_id: + existing_chain_id = new_chain_tuples.get(auth_event_id) + if not existing_chain_id: + existing_chain_id = chain_map[auth_event_id] + + new_chain_tuple = None # type: Optional[Tuple[Any, int]] + if existing_chain_id: + # We found a chain ID/sequence number candidate, check its + # not already taken. + proposed_new_id = existing_chain_id[0] + proposed_new_seq = existing_chain_id[1] + 1 + + if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq: + new_chain_tuple = ( + proposed_new_id, + proposed_new_seq, + ) + + # If we need to start a new chain we allocate a temporary chain ID. + if not new_chain_tuple: + new_chain_tuple = (object(), 1) + unallocated_chain_ids.add(new_chain_tuple[0]) + + new_chain_tuples[event_id] = new_chain_tuple + chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1] + + # Generate new chain IDs for all unallocated chain IDs. + newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn( + txn, len(unallocated_chain_ids) + ) + + # Map from potentially temporary chain ID to real chain ID + chain_id_to_allocated_map = dict( + zip(unallocated_chain_ids, newly_allocated_chain_ids) + ) # type: Dict[Any, int] + chain_id_to_allocated_map.update((c, c) for c in existing_chains) + + return { + event_id: (chain_id_to_allocated_map[chain_id], seq) + for event_id, (chain_id, seq) in new_chain_tuples.items() + } + def _persist_transaction_ids_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..5ca4fa6817 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_txn(txn): sql = ( "SELECT stream_ordering, event_id, json FROM events" @@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" - for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): - clump = update_rows[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, update_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 - def reindex_search_txn(txn): sql = ( "SELECT stream_ordering, event_id FROM events" @@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" - for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE): - clump = rows_to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + txn.execute_batch(sql, rows_to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py new file mode 100644
index 0000000000..0ac1da9c35 --- /dev/null +++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict, List + +from synapse.api.errors import SynapseError +from synapse.storage._base import SQLBaseStore + +logger = logging.getLogger(__name__) + + +class EventForwardExtremitiesStore(SQLBaseStore): + async def delete_forward_extremities_for_room(self, room_id: str) -> int: + """Delete any extra forward extremities for a room. + + Invalidates the "get_latest_event_ids_in_room" cache if any forward + extremities were deleted. + + Returns count deleted. + """ + + def delete_forward_extremities_for_room_txn(txn): + # First we need to get the event_id to not delete + sql = """ + SELECT event_id FROM event_forward_extremities + INNER JOIN events USING (room_id, event_id) + WHERE room_id = ? + ORDER BY stream_ordering DESC + LIMIT 1 + """ + txn.execute(sql, (room_id,)) + rows = txn.fetchall() + try: + event_id = rows[0][0] + logger.debug( + "Found event_id %s as the forward extremity to keep for room %s", + event_id, + room_id, + ) + except KeyError: + msg = "No forward extremity event found for room %s" % room_id + logger.warning(msg) + raise SynapseError(400, msg) + + # Now delete the extra forward extremities + sql = """ + DELETE FROM event_forward_extremities + WHERE event_id != ? AND room_id = ? + """ + + txn.execute(sql, (event_id, room_id)) + logger.info( + "Deleted %s extra forward extremities for room %s", + txn.rowcount, + room_id, + ) + + if txn.rowcount > 0: + # Invalidate the cache + self._invalidate_cache_and_stream( + txn, self.get_latest_event_ids_in_room, (room_id,), + ) + + return txn.rowcount + + return await self.db_pool.runInteraction( + "delete_forward_extremities_for_room", + delete_forward_extremities_for_room_txn, + ) + + async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]: + """Get list of forward extremities for a room.""" + + def get_forward_extremities_for_room_txn(txn): + sql = """ + SELECT event_id, state_group, depth, received_ts + FROM event_forward_extremities + INNER JOIN event_to_state_groups USING (event_id) + INNER JOIN events USING (room_id, event_id) + WHERE room_id = ? + """ + + txn.execute(sql, (room_id,)) + return self.db_pool.cursor_to_dict(txn) + + return await self.db_pool.runInteraction( + "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, + ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 71d823be72..5c4c251871 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -525,7 +525,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: List[EventTypes], + state_types_to_include: List[str], membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..e017177655 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_origin = ? AND media_id = ?" ) - txn.executemany( + txn.execute_batch( sql, ( (time_ms, media_origin, media_id) @@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): " WHERE media_id = ?" ) - txn.executemany(sql, ((time_ms, media_id) for media_id in local_media)) + txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media)) return await self.db_pool.runInteraction( "update_cached_last_access_time", update_cache_txn @@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?" def _delete_url_cache_txn(txn): - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache", _delete_url_cache_txn @@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): def _delete_url_cache_media_txn(txn): sql = "DELETE FROM local_media_repository WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?" - txn.executemany(sql, [(media_id,) for media_id in media_ids]) + txn.execute_batch(sql, [(media_id,) for media_id in media_ids]) return await self.db_pool.runInteraction( "delete_url_cache_media", _delete_url_cache_media_txn diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index ab18cc4d79..92e65aa640 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py
@@ -88,6 +88,62 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (x[0] - 1) * x[1] for x in res if x[1] ) + async def count_daily_e2ee_messages(self): + """ + Returns an estimate of the number of messages sent in the last day. + + If it has been significantly less or more than one day since the last + call to this function, it will return None. + """ + + def _count_messages(txn): + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.encrypted' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) + + async def count_daily_sent_e2ee_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.encrypted' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction( + "count_daily_sent_e2ee_messages", _count_messages + ) + + async def count_daily_active_e2ee_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.encrypted' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + (count,) = txn.fetchone() + return count + + return await self.db_pool.runInteraction( + "count_daily_active_e2ee_rooms", _count + ) + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 54ef0f1f54..4360dc0afc 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,11 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo +from synapse.types import UserID +from synapse.util.caches.descriptors import cached + +BATCH_SIZE = 100 class ProfileWorkerStore(SQLBaseStore): @@ -39,6 +44,7 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) + @cached(max_entries=5000) async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", @@ -47,6 +53,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) + @cached(max_entries=5000) async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", @@ -55,6 +62,58 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) + async def get_latest_profile_replication_batch_number(self): + def f(txn): + txn.execute("SELECT MAX(batch) as maxbatch FROM profiles") + rows = self.db_pool.cursor_to_dict(txn) + return rows[0]["maxbatch"] + + return await self.db_pool.runInteraction( + "get_latest_profile_replication_batch_number", f + ) + + async def get_profile_batch(self, batchnum): + return await self.db_pool.simple_select_list( + table="profiles", + keyvalues={"batch": batchnum}, + retcols=("user_id", "displayname", "avatar_url", "active"), + desc="get_profile_batch", + ) + + async def assign_profile_batch(self): + def f(txn): + sql = ( + "UPDATE profiles SET batch = " + "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) " + "WHERE user_id in (" + " SELECT user_id FROM profiles WHERE batch is NULL limit ?" + ")" + ) + txn.execute(sql, (BATCH_SIZE,)) + return txn.rowcount + + return await self.db_pool.runInteraction("assign_profile_batch", f) + + async def get_replication_hosts(self): + def f(txn): + txn.execute( + "SELECT host, last_synced_batch FROM profile_replication_status" + ) + rows = self.db_pool.cursor_to_dict(txn) + return {r["host"]: r["last_synced_batch"] for r in rows} + + return await self.db_pool.runInteraction("get_replication_hosts", f) + + async def update_replication_batch_for_host( + self, host: str, last_synced_batch: int + ): + return await self.db_pool.simple_upsert( + table="profile_replication_status", + keyvalues={"host": host}, + values={"last_synced_batch": last_synced_batch}, + desc="update_replication_batch_for_host", + ) + async def get_from_remote_profile_cache( self, user_id: str ) -> Optional[Dict[str, Any]]: @@ -72,32 +131,95 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: Optional[str] + self, user_localpart: str, new_displayname: Optional[str], batchnum: int ) -> None: - await self.db_pool.simple_update_one( + # Invalidate the read cache for this user + self.get_profile_displayname.invalidate((user_localpart,)) + + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, + values={"displayname": new_displayname, "batch": batchnum}, desc="set_profile_displayname", + lock=False, # we can do this because user_id has a unique index ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: Optional[str] + self, user_localpart: str, new_avatar_url: Optional[str], batchnum: int ) -> None: - await self.db_pool.simple_update_one( + # Invalidate the read cache for this user + self.get_profile_avatar_url.invalidate((user_localpart,)) + + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url, "batch": batchnum}, desc="set_profile_avatar_url", + lock=False, # we can do this because user_id has a unique index + ) + + async def set_profiles_active( + self, users: List[UserID], active: bool, hide: bool, batchnum: int, + ) -> None: + """Given a set of users, set active and hidden flags on them. + + Args: + users: A list of UserIDs + active: Whether to set the users to active or inactive + hide: Whether to hide the users (withold from replication). If + False and active is False, users will have their profiles + erased + batchnum: The batch number, used for profile replication + """ + # Convert list of localparts to list of tuples containing localparts + user_localparts = [(user.localpart,) for user in users] + + # Generate list of value tuples for each user + value_names = ("active", "batch") + values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple] + + if not active and not hide: + # we are deactivating for real (not in hide mode) + # so clear the profile information + value_names += ("avatar_url", "displayname") + values = [v + (None, None) for v in values] + + return await self.db_pool.runInteraction( + "set_profiles_active", + self.db_pool.simple_upsert_many_txn, + table="profiles", + key_names=("user_id",), + key_values=user_localparts, + value_names=value_names, + value_values=values, + ) + + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + await self.db_pool.simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", ) async def update_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> int: - return await self.db_pool.simple_update( + return await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, - updatevalues={ + values={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), @@ -166,6 +288,17 @@ class ProfileWorkerStore(SQLBaseStore): class ProfileStore(ProfileWorkerStore): + def __init__(self, database, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "profile_replication_status_host_index", + index_name="profile_replication_status_idx", + table="profile_replication_status", + columns=["host"], + unique=True, + ) + async def add_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> None: diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): ) # Update backward extremeties - txn.executemany( + txn.execute_batch( "INSERT INTO event_backward_extremities (room_id, event_id)" " VALUES (?, ?)", [(room_id, event_id) for event_id, in new_backwards_extrems], diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..2687ef3e43 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore): txn, self.get_if_user_has_pusher, (user_id,) ) - self.db_pool.simple_delete_one_txn( + # It is expected that there is exactly one pusher to delete, but + # if it isn't there (or there are multiple) delete them all. + self.db_pool.simple_delete_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 14c0878d81..269eb6e6e7 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -82,12 +82,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): database.engine, find_max_generated_user_id_localpart, "user_id_seq", ) - self._account_validity = hs.config.account_validity - if hs.config.run_background_tasks and self._account_validity.enabled: - self._clock.call_later( - 0.0, self._set_expiration_date_when_missing, + self._account_validity_enabled = hs.config.account_validity_enabled + if self._account_validity_enabled: + self._account_validity_period = hs.config.account_validity_period + self._account_validity_startup_job_max_delta = ( + hs.config.account_validity_startup_job_max_delta ) + if hs.config.run_background_tasks: + self._clock.call_later( + 0.0, self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: self._clock.looping_call( @@ -183,6 +189,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts: int, email_sent: bool, renewal_token: Optional[str] = None, + token_used_ts: Optional[int] = None, ) -> None: """Updates the account validity properties of the given account, with the given values. @@ -196,6 +203,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): period. renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. + token_used_ts: A timestamp of when the current token was used to renew + the account. """ def set_account_validity_for_user_txn(txn): @@ -207,6 +216,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "expiration_ts_ms": expiration_ts, "email_sent": email_sent, "renewal_token": renewal_token, + "token_used_ts_ms": token_used_ts, }, ) self._invalidate_cache_and_stream( @@ -217,10 +227,41 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "set_account_validity_for_user", set_account_validity_for_user_txn ) + async def get_expired_users(self): + """Get UserIDs of all expired users. + + Users who are not active, or do not have profile information, are + excluded from the results. + + Returns: + Deferred[List[UserID]]: List of expired user IDs + """ + + def get_expired_users_txn(txn, now_ms): + # We need to use pattern matching as profiles.user_id is confusingly just the + # user's localpart, whereas account_validity.user_id is a full user ID + sql = """ + SELECT av.user_id from account_validity AS av + LEFT JOIN profiles as p + ON av.user_id LIKE '%%' || p.user_id || ':%%' + WHERE expiration_ts_ms <= ? + AND p.active = 1 + """ + txn.execute(sql, (now_ms,)) + rows = txn.fetchall() + + return [UserID.from_string(row[0]) for row in rows] + + res = await self.db_pool.runInteraction( + "get_expired_users", get_expired_users_txn, self._clock.time_msec() + ) + + return res + async def set_renewal_token_for_user( self, user_id: str, renewal_token: str ) -> None: - """Defines a renewal token for a given user. + """Defines a renewal token for a given user, and clears the token_used timestamp. Args: user_id: ID of the user to set the renewal token for. @@ -233,26 +274,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, - updatevalues={"renewal_token": renewal_token}, + updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None}, desc="set_renewal_token_for_user", ) - async def get_user_from_renewal_token(self, renewal_token: str) -> str: - """Get a user ID from a renewal token. + async def get_user_from_renewal_token( + self, renewal_token: str + ) -> Tuple[str, int, Optional[int]]: + """Get a user ID and renewal status from a renewal token. Args: renewal_token: The renewal token to perform the lookup with. Returns: - The ID of the user to which the token belongs. + A tuple of containing the following values: + * The ID of a user to which the token belongs. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. + * An optional int representing the timestamp of when the user renewed their + account timestamp as milliseconds since the epoch. None if the account + has not been renewed using the current token yet. """ - return await self.db_pool.simple_select_one_onecol( + ret_dict = await self.db_pool.simple_select_one( table="account_validity", keyvalues={"renewal_token": renewal_token}, - retcol="user_id", + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], desc="get_user_from_renewal_token", ) + return ( + ret_dict["user_id"], + ret_dict["expiration_ts_ms"], + ret_dict["token_used_ts_ms"], + ) + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. @@ -291,7 +346,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "get_users_expiring_soon", select_users_txn, self._clock.time_msec(), - self.config.account_validity.renew_at, + self.config.account_validity_renew_at, ) async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: @@ -323,6 +378,54 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="delete_account_validity_for_user", ) + async def get_info_for_users( + self, user_ids: List[str], + ): + """Return the user info for a given set of users + + Args: + user_ids: A list of users to return information about + + Returns: + Deferred[Dict[str, bool]]: A dictionary mapping each user ID to + a dict with the following keys: + * expired - whether this is an expired user + * deactivated - whether this is a deactivated user + """ + # Get information of all our local users + def _get_info_for_users_txn(txn): + rows = [] + + for user_id in user_ids: + sql = """ + SELECT u.name, u.deactivated, av.expiration_ts_ms + FROM users as u + LEFT JOIN account_validity as av + ON av.user_id = u.name + WHERE u.name = ? + """ + + txn.execute(sql, (user_id,)) + row = txn.fetchone() + if row: + rows.append(row) + + return rows + + info_rows = await self.db_pool.runInteraction( + "get_info_for_users", _get_info_for_users_txn + ) + + return { + user_id: { + "expired": ( + expiration is not None and self._clock.time_msec() >= expiration + ), + "deactivated": deactivated == 1, + } + for user_id, deactivated, expiration in info_rows + } + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver. @@ -360,6 +463,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) + async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None: + """Sets whether a user shadow-banned. + + Args: + user: user ID of the user to test + shadow_banned: true iff the user is to be shadow-banned, false otherwise. + """ + + def set_shadow_banned_txn(txn): + self.db_pool.simple_update_one_txn( + txn, + table="users", + keyvalues={"name": user.to_string()}, + updatevalues={"shadow_banned": shadow_banned}, + ) + # In order for this to apply immediately, clear the cache for this user. + tokens = self.db_pool.simple_select_onecol_txn( + txn, + table="access_tokens", + keyvalues={"user_id": user.to_string()}, + retcol="token", + ) + for token in tokens: + self._invalidate_cache_and_stream( + txn, self.get_user_by_access_token, (token,) + ) + + await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, @@ -922,11 +1054,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period + expiration_ts = now_ms + self._account_validity_period if use_delta: expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts - self._account_validity_startup_job_max_delta, expiration_ts, ) @@ -1124,7 +1256,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): FROM user_threepids """ - txn.executemany(sql, [(id_server,) for id_server in id_servers]) + txn.execute_batch(sql, [(id_server,) for id_server in id_servers]) if id_servers: await self.db_pool.runInteraction( @@ -1364,7 +1496,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): except self.database_engine.module.IntegrityError: raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) - if self._account_validity.enabled: + if self._account_validity_enabled: self.set_expiration_date_for_user_txn(txn, user_id) if create_profile_with_displayname: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index a9fcb5f59c..a98b423771 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
@@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import collections import logging from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.storage._base import SQLBaseStore, db_to_json @@ -178,11 +177,13 @@ class RoomWorkerStore(SQLBaseStore): INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR history_visibility = 'world_readable' + join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + OR history_visibility = 'world_readable' ) AND joined_members > 0 """ % { - "published_sql": published_sql + "published_sql": published_sql, + "knock_join_rule": JoinRules.KNOCK, } txn.execute(sql, query_args) @@ -305,7 +306,7 @@ class RoomWorkerStore(SQLBaseStore): sql = """ SELECT room_id, name, topic, canonical_alias, joined_members, - avatar, history_visibility, joined_members, guest_access + avatar, history_visibility, guest_access, join_rules FROM ( %(published_sql)s ) published @@ -313,7 +314,8 @@ class RoomWorkerStore(SQLBaseStore): INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR history_visibility = 'world_readable' + join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + OR history_visibility = 'world_readable' ) AND joined_members > 0 %(where_clause)s @@ -322,6 +324,7 @@ class RoomWorkerStore(SQLBaseStore): "published_sql": published_sql, "where_clause": where_clause, "dir": "DESC" if forwards else "ASC", + "knock_join_rule": JoinRules.KNOCK, } if limit is not None: @@ -356,6 +359,23 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def is_room_published(self, room_id: str) -> bool: + """Check whether a room has been published in the local public room + directory. + + Args: + room_id + Returns: + Whether the room is currently published in the room directory + """ + # Get room information + room_info = await self.get_room(room_id) + if not room_info: + return False + + # Check the is_public value + return room_info.get("is_public", False) + async def get_rooms_paginate( self, start: int, @@ -564,6 +584,11 @@ class RoomWorkerStore(SQLBaseStore): Returns: dict[int, int]: "min_lifetime" and "max_lifetime" for this room. """ + # If the room retention feature is disabled, return a policy with no minimum nor + # maximum, in order not to filter out events we should filter out when sending to + # the client. + if not self.config.retention_enabled: + return {"min_lifetime": None, "max_lifetime": None} def get_retention_policy_for_room_txn(txn): txn.execute( diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): "max_stream_id_exclusive", self._stream_order_on_start + 1 ) - INSERT_CLUMP_SIZE = 1000 - def add_membership_profile_txn(txn): sql = """ SELECT stream_ordering, event_id, events.room_id, event_json.json @@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): UPDATE room_memberships SET display_name = ?, avatar_url = ? WHERE event_id = ? AND room_id = ? """ - for index in range(0, len(to_update), INSERT_CLUMP_SIZE): - clump = to_update[index : index + INSERT_CLUMP_SIZE] - txn.executemany(to_update_sql, clump) + txn.execute_batch(to_update_sql, to_update) progress = { "target_min_stream_id_inclusive": target_min_stream_id, diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql new file mode 100644
index 0000000000..e744c02fe8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Add a batch number to track changes to profiles and the + * order they're made in so we can replicate user profiles + * to other hosts as they change + */ +ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL; + +/* + * Index on the batch number so we can get profiles + * by their batch + */ +CREATE INDEX profiles_batch_idx ON profiles(batch); + +/* + * A table to track what batch of user profiles has been + * synced to what profile replication target. + */ +CREATE TABLE profile_replication_status ( + host TEXT NOT NULL, + last_synced_batch BIGINT NOT NULL +); diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql new file mode 100644
index 0000000000..96051ac179 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * A flag saying whether the user owning the profile has been deactivated + * This really belongs on the users table, not here, but the users table + * stores users by their full user_id and profiles stores them by localpart, + * so we can't easily join between the two tables. Plus, the batch number + * realy ought to represent data in this table that has changed. + */ +ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL; \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql new file mode 100644
index 0000000000..7542ab8cbd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host); \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql new file mode 100644
index 0000000000..4836dac16e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql
@@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track when users renew their account using the value of the 'renewal_token' column. +-- This field should be set to NULL after a fresh token is generated. +ALTER TABLE account_validity ADD token_used_ts_ms BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql b/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql new file mode 100644
index 0000000000..658f55a384 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql
@@ -0,0 +1,17 @@ +/* Copyright 2020 Sorunome + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_stats_current ADD knocked_members INT NOT NULL DEFAULT '0'; +ALTER TABLE room_stats_historical ADD knocked_members BIGINT NOT NULL DEFAULT '0'; diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644 --- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py +++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs # { "ignored_users": "@someone:example.org": {} } ignored_users = content.get("ignored_users", {}) if isinstance(ignored_users, dict) and ignored_users: - cur.executemany(insert_sql, [(user_id, u) for u in ignored_users]) + cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users]) # Add indexes after inserting data for efficiency. logger.info("Adding constraints to ignored_users table") diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644 --- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream ( +CREATE TABLE profile_replication_status ( + host text NOT NULL, + last_synced_batch bigint NOT NULL +); + + + CREATE TABLE profiles ( user_id text NOT NULL, displayname text, - avatar_url text + avatar_url text, + batch bigint, + active smallint DEFAULT 1 NOT NULL ); @@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); +CREATE INDEX profiles_batch_idx ON profiles USING btree (batch); + + + CREATE INDEX public_room_index ON rooms USING btree (is_public); diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644 --- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) ); CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) ); CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); -CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); +CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) ); CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER ); CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) ); @@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id); CREATE INDEX group_invites_u_idx ON group_invites(user_id); CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +CREATE INDEX profiles_batch_idx ON profiles(batch); +CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL ); CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL ); CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e34fce6281..f5e7d9ef98 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import Collection logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) elif isinstance(self.database_engine, Sqlite3Engine): sql = ( @@ -75,7 +76,7 @@ class SearchWorkerStore(SQLBaseStore): for entry in entries ) - txn.executemany(sql, args) + txn.execute_batch(sql, args) else: # This should be unreachable. raise Exception("Unrecognized database engine") @@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore): async def search_rooms( self, - room_ids: List[str], + room_ids: Collection[str], search_term: str, keys: List[str], limit, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0cdb3ec1f7..4ad363fb0d 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -15,11 +15,12 @@ # limitations under the License. import logging -from collections import Counter from enum import Enum from itertools import chain from typing import Any, Dict, List, Optional, Tuple +from typing_extensions import Counter + from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -41,6 +42,7 @@ ABSOLUTE_STATS_FIELDS = { "current_state_events", "joined_members", "invited_members", + "knocked_members", "left_members", "banned_members", "local_users_in_room", @@ -319,7 +321,9 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: + async def get_earliest_token_for_stats( + self, stats_type: str, id: str + ) -> Optional[int]: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -339,7 +343,7 @@ class StatsStore(StateDeltasStore): ) async def bulk_update_stats_delta( - self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int + self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int ) -> None: """Bulk update stats tables for a given stream_id and updates the stats incremental position. @@ -665,7 +669,7 @@ class StatsStore(StateDeltasStore): async def get_changes_room_total_events_and_bytes( self, min_pos: int, max_pos: int - ) -> Dict[str, Dict[str, int]]: + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Fetches the counts of events in the given range of stream IDs. Args: @@ -683,18 +687,19 @@ class StatsStore(StateDeltasStore): max_pos, ) - def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): + def get_changes_room_total_events_and_bytes_txn( + self, txn, low_pos: int, high_pos: int + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Gets the total_events and total_event_bytes counts for rooms and senders, in a range of stream_orderings (including backfilled events). Args: txn - low_pos (int): Low stream ordering - high_pos (int): High stream ordering + low_pos: Low stream ordering + high_pos: High stream ordering Returns: - tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The - room and user deltas for total_events/total_event_bytes in the + The room and user deltas for total_events/total_event_bytes in the format of `stats_id` -> fields """ diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index ef11f1c3b3..336da218ed 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - async def update_user_directory_stream_pos(self, stream_id: str) -> None: + async def update_user_directory_stream_pos(self, stream_id: int) -> None: await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, @@ -558,6 +558,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self._prefer_local_users_in_search = ( + hs.config.user_directory_search_prefer_local_users + ) + self._server_name = hs.config.server_name + async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): self.db_pool.simple_delete_txn( @@ -750,9 +755,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) """ + # We allow manipulating the ranking algorithm by injecting statements + # based on config options. + additional_ordering_statements = [] + ordering_arguments = () + if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + # If enabled, this config option will rank local users higher than those on + # remote instances. + if self._prefer_local_users_in_search: + # This statement checks whether a given user's user ID contains a server name + # that matches the local server + statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)" + additional_ordering_statements.append(statement) + + ordering_arguments += ("%:" + self._server_name,) + # We order by rank and then if they have profile info # The ranking algorithm is hand tweaked for "best" results. Broadly # the idea is we give a higher weight to exact matches. @@ -763,7 +783,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) WHERE - %s + %(where_clause)s AND vector @@ to_tsquery('simple', ?) ORDER BY (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) @@ -783,33 +803,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): 8 ) ) + %(order_case_statements)s DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? - """ % ( - where_clause, + """ % { + "where_clause": where_clause, + "order_case_statements": " ".join(additional_ordering_statements), + } + args = ( + join_args + + (full_query, exact_query, prefix_query) + + ordering_arguments + + (limit + 1,) ) - args = join_args + (full_query, exact_query, prefix_query, limit + 1) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) + # If enabled, this config option will rank local users higher than those on + # remote instances. + if self._prefer_local_users_in_search: + # This statement checks whether a given user's user ID contains a server name + # that matches the local server + # + # Note that we need to include a comma at the end for valid SQL + statement = "user_id LIKE ? DESC," + additional_ordering_statements.append(statement) + + ordering_arguments += ("%:" + self._server_name,) + sql = """ SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) WHERE - %s + %(where_clause)s AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, + %(order_statements)s display_name IS NULL, avatar_url IS NULL LIMIT ? - """ % ( - where_clause, - ) - args = join_args + (search_query, limit + 1) + """ % { + "where_clause": where_clause, + "order_statements": " ".join(additional_ordering_statements), + } + args = join_args + (search_query,) + ordering_arguments + (limit + 1,) else: # This should be unreachable. raise Exception("Unrecognized database engine") diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) logger.info("[purge] removing redundant state groups") - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups_state WHERE state_group = ?", ((sg,) for sg in state_groups_to_delete), ) - txn.executemany( + txn.execute_batch( "DELETE FROM state_groups WHERE id = ?", ((sg,) for sg in state_groups_to_delete), ) diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..71ef5a72dc 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@ import heapq import logging import threading -from collections import deque +from collections import OrderedDict from contextlib import contextmanager from typing import Dict, List, Optional, Set, Tuple, Union import attr -from typing_extensions import Deque from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction @@ -101,7 +100,13 @@ class StreamIdGenerator: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() # type: Deque[int] + + # We use this as an ordered set, as we want to efficiently append items, + # remove items and get the first item. Since we insert IDs in order, the + # insertion ordering will ensure its in the correct ordering. + # + # The key and values are the same, but we never look at the values. + self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int] def get_next(self): """ @@ -113,7 +118,7 @@ class StreamIdGenerator: self._current += self._step next_id = self._current - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -121,7 +126,7 @@ class StreamIdGenerator: yield next_id finally: with self._lock: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -140,7 +145,7 @@ class StreamIdGenerator: self._current += n * self._step for next_id in next_ids: - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -149,7 +154,7 @@ class StreamIdGenerator: finally: with self._lock: for next_id in next_ids: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -162,7 +167,7 @@ class StreamIdGenerator: """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - self._step + return next(iter(self._unfinished_ids)) - self._step return self._current diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..0ec4dc2918 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta): ... @abc.abstractmethod + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + """Get the next `n` IDs in the sequence""" + ... + + @abc.abstractmethod def check_consistency( self, db_conn: "LoggingDatabaseConnection", @@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + first_id = self._current_max_id + 1 + self._current_max_id += n + return [first_id + i for i in range(n)] + def check_consistency( self, db_conn: Connection, diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py new file mode 100644
index 0000000000..1453d04571 --- /dev/null +++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py new file mode 100644
index 0000000000..4589e4539b --- /dev/null +++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,971 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import email.utils +import logging +from typing import Dict, List, Optional, Tuple + +from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset +from synapse.api.errors import SynapseError +from synapse.config._base import ConfigError +from synapse.events import EventBase +from synapse.module_api import ModuleApi +from synapse.types import Requester, StateMap, UserID, get_domain_from_id + +logger = logging.getLogger(__name__) + +ACCESS_RULES_TYPE = "im.vector.room.access_rules" + + +class AccessRules: + DIRECT = "direct" + RESTRICTED = "restricted" + UNRESTRICTED = "unrestricted" + + +VALID_ACCESS_RULES = ( + AccessRules.DIRECT, + AccessRules.RESTRICTED, + AccessRules.UNRESTRICTED, +) + +# Rules to which we need to apply the power levels restrictions. +# +# These are all of the rules that neither: +# * forbid users from joining based on a server blacklist (which means that there +# is no need to apply power level restrictions), nor +# * target direct chats (since we allow both users to be room admins in this case). +# +# The power-level restrictions, when they are applied, prevent the following: +# * the default power level for users (users_default) being set to anything other than 0. +# * a non-default power level being assigned to any user which would be forbidden from +# joining a restricted room. +RULES_WITH_RESTRICTED_POWER_LEVELS = (AccessRules.UNRESTRICTED,) + + +class RoomAccessRules(object): + """Implementation of the ThirdPartyEventRules module API that allows federation admins + to define custom rules for specific events and actions. + Implements the custom behaviour for the "im.vector.room.access_rules" state event. + + Takes a config in the format: + + third_party_event_rules: + module: third_party_rules.RoomAccessRules + config: + # List of domains (server names) that can't be invited to rooms if the + # "restricted" rule is set. Defaults to an empty list. + domains_forbidden_when_restricted: [] + + # Identity server to use when checking the HS an email address belongs to + # using the /info endpoint. Required. + id_server: "vector.im" + + # Enable freezing a room when the last room admin leaves. + # Note that the departing admin must be a local user in order for this feature + # to work. + freeze_room_with_no_admin: false + + Don't forget to consider if you can invite users from your own domain. + """ + + def __init__( + self, config: Dict, module_api: ModuleApi, + ): + self.id_server = config["id_server"] + self.module_api = module_api + + self.domains_forbidden_when_restricted = config.get( + "domains_forbidden_when_restricted", [] + ) + + self.freeze_room_with_no_admin = config.get("freeze_room_with_no_admin", False) + + @staticmethod + def parse_config(config: Dict) -> Dict: + """Parses and validates the options specified in the homeserver config. + + Args: + config: The config dict. + + Returns: + The config dict. + + Raises: + ConfigError: If there was an issue with the provided module configuration. + """ + if "id_server" not in config: + raise ConfigError("No IS for event rules TchapEventRules") + + return config + + async def on_create_room( + self, requester: Requester, config: Dict, is_requester_admin: bool, + ) -> bool: + """Implements synapse.events.ThirdPartyEventRules.on_create_room. + + Checks if a im.vector.room.access_rules event is being set during room creation. + If yes, make sure the event is correct. Otherwise, append an event with the + default rule to the initial state. + + Checks if a m.rooms.power_levels event is being set during room creation. + If yes, make sure the event is allowed. Otherwise, set power_level_content_override + in the config dict to our modified version of the default room power levels. + + Args: + requester: The user who is making the createRoom request. + config: The createRoom config dict provided by the user. + is_requester_admin: Whether the requester is a Synapse admin. + + Returns: + Whether the request is allowed. + + Raises: + SynapseError: If the createRoom config dict is invalid or its contents blocked. + """ + is_direct = config.get("is_direct") + preset = config.get("preset") + access_rule = None + join_rule = None + + # If there's a rules event in the initial state, check if it complies with the + # spec for im.vector.room.access_rules and deny the request if not. + for event in config.get("initial_state", []): + if event["type"] == ACCESS_RULES_TYPE: + access_rule = event["content"].get("rule") + + # Make sure the event has a valid content. + if access_rule is None: + raise SynapseError(400, "Invalid access rule") + + # Make sure the rule name is valid. + if access_rule not in VALID_ACCESS_RULES: + raise SynapseError(400, "Invalid access rule") + + if (is_direct and access_rule != AccessRules.DIRECT) or ( + access_rule == AccessRules.DIRECT and not is_direct + ): + raise SynapseError(400, "Invalid access rule") + + if event["type"] == EventTypes.JoinRules: + join_rule = event["content"].get("join_rule") + + if access_rule is None: + # If there's no access rules event in the initial state, create one with the + # default setting. + if is_direct: + default_rule = AccessRules.DIRECT + else: + # If the default value for non-direct chat changes, we should make another + # case here for rooms created with either a "public" join_rule or the + # "public_chat" preset to make sure those keep defaulting to "restricted" + default_rule = AccessRules.RESTRICTED + + if not config.get("initial_state"): + config["initial_state"] = [] + + config["initial_state"].append( + { + "type": ACCESS_RULES_TYPE, + "state_key": "", + "content": {"rule": default_rule}, + } + ) + + access_rule = default_rule + + # Check that the preset in use is compatible with the access rule, whether it's + # user-defined or the default. + # + # Direct rooms may not have their join_rules set to JoinRules.PUBLIC. + if ( + join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT + ) and access_rule == AccessRules.DIRECT: + raise SynapseError(400, "Invalid access rule") + + default_power_levels = self._get_default_power_levels( + requester.user.to_string() + ) + + # Check if the creator can override values for the power levels. + allowed = self._is_power_level_content_allowed( + config.get("power_level_content_override", {}), + access_rule, + default_power_levels, + ) + if not allowed: + raise SynapseError(400, "Invalid power levels content override") + + custom_user_power_levels = config.get("power_level_content_override") + + # Second loop for events we need to know the current rule to process. + for event in config.get("initial_state", []): + if event["type"] == EventTypes.PowerLevels: + allowed = self._is_power_level_content_allowed( + event["content"], access_rule, default_power_levels + ) + if not allowed: + raise SynapseError(400, "Invalid power levels content") + + custom_user_power_levels = event["content"] + if custom_user_power_levels: + # If the user is using their own power levels, but failed to provide an expected + # key in the power levels content dictionary, fill it in from the defaults instead + for key, value in default_power_levels.items(): + custom_user_power_levels.setdefault(key, value) + else: + # If power levels were not overridden by the user, completely override with the + # defaults instead + config["power_level_content_override"] = default_power_levels + + return True + + # If power levels are not overridden by the user during room creation, the following + # rules are used instead. Changes from Synapse's default power levels are noted. + # + # The same power levels are currently applied regardless of room preset. + @staticmethod + def _get_default_power_levels(user_id: str) -> Dict: + return { + "users": {user_id: 100}, + "users_default": 0, + "events": { + EventTypes.Name: 50, + EventTypes.PowerLevels: 100, + EventTypes.RoomHistoryVisibility: 100, + EventTypes.CanonicalAlias: 50, + EventTypes.RoomAvatar: 50, + EventTypes.Tombstone: 100, + EventTypes.ServerACL: 100, + EventTypes.RoomEncryption: 100, + }, + "events_default": 0, + "state_default": 100, # Admins should be the only ones to perform other tasks + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 50, # All rooms should require mod to invite, even private + } + + async def check_threepid_can_be_invited( + self, medium: str, address: str, state_events: StateMap[EventBase], + ) -> bool: + """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited. + + Check if a threepid can be invited to the room via a 3PID invite given the current + rules and the threepid's address, by retrieving the HS it's mapped to from the + configured identity server, and checking if we can invite users from it. + + Args: + medium: The medium of the threepid. + address: The address of the threepid. + state_events: A dict mapping (event type, state key) to state event. + State events in the room the threepid is being invited to. + + Returns: + Whether the threepid invite is allowed. + """ + rule = self._get_rule_from_state(state_events) + + if medium != "email": + return False + + if rule != AccessRules.RESTRICTED: + # Only "restricted" requires filtering 3PID invites. We don't need to do + # anything for "direct" here, because only "restricted" requires filtering + # based on the HS the address is mapped to. + return True + + parsed_address = email.utils.parseaddr(address)[1] + if parsed_address != address: + # Avoid reproducing the security issue described here: + # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2 + # It's probably not worth it but let's just be overly safe here. + return False + + # Get the HS this address belongs to from the identity server. + res = await self.module_api.http_client.get_json( + "https://%s/_matrix/identity/api/v1/info" % (self.id_server,), + {"medium": medium, "address": address}, + ) + + # Look for a domain that's not forbidden from being invited. + if not res.get("hs"): + return False + if res.get("hs") in self.domains_forbidden_when_restricted: + return False + + return True + + async def check_event_allowed( + self, event: EventBase, state_events: StateMap[EventBase], + ) -> bool: + """Implements synapse.events.ThirdPartyEventRules.check_event_allowed. + + Checks the event's type and the current rule and calls the right function to + determine whether the event can be allowed. + + Args: + event: The event to check. + state_events: A dict mapping (event type, state key) to state event. + State events in the room the event originated from. + + Returns: + True if the event can be allowed, False otherwise. + """ + if event.type == ACCESS_RULES_TYPE: + return await self._on_rules_change(event, state_events) + + # We need to know the rule to apply when processing the event types below. + rule = self._get_rule_from_state(state_events) + + if event.type == EventTypes.PowerLevels: + return self._is_power_level_content_allowed( + event.content, rule, on_room_creation=False + ) + + if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite: + return await self._on_membership_or_invite(event, rule, state_events) + + if event.type == EventTypes.JoinRules: + return self._on_join_rule_change(event, rule) + + if event.type == EventTypes.RoomAvatar: + return self._on_room_avatar_change(event, rule) + + if event.type == EventTypes.Name: + return self._on_room_name_change(event, rule) + + if event.type == EventTypes.Topic: + return self._on_room_topic_change(event, rule) + + return True + + async def check_visibility_can_be_modified( + self, room_id: str, state_events: StateMap[EventBase], new_visibility: str + ) -> bool: + """Implements + synapse.events.ThirdPartyEventRules.check_visibility_can_be_modified + + Determines whether a room can be published, or removed from, the public room + list. A room is published if its visibility is set to "public". Otherwise, + its visibility is "private". A room with access rule other than "restricted" + may not be published. + + Args: + room_id: The ID of the room. + state_events: A dict mapping (event type, state key) to state event. + State events in the room. + new_visibility: The new visibility state. Either "public" or "private". + + Returns: + Whether the room is allowed to be published to, or removed from, the public + rooms directory. + """ + # We need to know the rule to apply when processing the event types below. + rule = self._get_rule_from_state(state_events) + + # Allow adding a room to the public rooms list only if it is restricted + if new_visibility == "public": + return rule == AccessRules.RESTRICTED + + # By default a room is created as "restricted", meaning it is allowed to be + # published to the public rooms directory. + return True + + async def _on_rules_change( + self, event: EventBase, state_events: StateMap[EventBase] + ): + """Checks whether an im.vector.room.access_rules event is forbidden or allowed. + + Args: + event: The im.vector.room.access_rules event. + state_events: A dict mapping (event type, state key) to state event. + State events in the room before the event was sent. + Returns: + True if the event can be allowed, False otherwise. + """ + new_rule = event.content.get("rule") + + # Check for invalid values. + if new_rule not in VALID_ACCESS_RULES: + return False + + # Make sure we don't apply "direct" if the room has more than two members. + if new_rule == AccessRules.DIRECT: + existing_members, threepid_tokens = self._get_members_and_tokens_from_state( + state_events + ) + + if len(existing_members) > 2 or len(threepid_tokens) > 1: + return False + + if new_rule != AccessRules.RESTRICTED: + # Block this change if this room is currently listed in the public rooms + # directory + if await self.module_api.public_room_list_manager.room_is_in_public_room_list( + event.room_id + ): + return False + + prev_rules_event = state_events.get((ACCESS_RULES_TYPE, "")) + + # Now that we know the new rule doesn't break the "direct" case, we can allow any + # new rule in rooms that had none before. + if prev_rules_event is None: + return True + + prev_rule = prev_rules_event.content.get("rule") + + # Currently, we can only go from "restricted" to "unrestricted". + return ( + prev_rule == AccessRules.RESTRICTED and new_rule == AccessRules.UNRESTRICTED + ) + + async def _on_membership_or_invite( + self, event: EventBase, rule: str, state_events: StateMap[EventBase], + ) -> bool: + """Applies the correct rule for incoming m.room.member and + m.room.third_party_invite events. + + Args: + event: The event to check. + rule: The name of the rule to apply. + state_events: A dict mapping (event type, state key) to state event. + The state of the room before the event was sent. + + Returns: + True if the event can be allowed, False otherwise. + """ + if rule == AccessRules.RESTRICTED: + ret = self._on_membership_or_invite_restricted(event) + elif rule == AccessRules.UNRESTRICTED: + ret = self._on_membership_or_invite_unrestricted(event, state_events) + elif rule == AccessRules.DIRECT: + ret = self._on_membership_or_invite_direct(event, state_events) + else: + # We currently apply the default (restricted) if we don't know the rule, we + # might want to change that in the future. + ret = self._on_membership_or_invite_restricted(event) + + if event.type == "m.room.member": + # If this is an admin leaving, and they are the last admin in the room, + # raise the power levels of the room so that the room is 'frozen'. + # + # We have to freeze the room by puppeting an admin user, which we can + # only do for local users + if ( + self.freeze_room_with_no_admin + and self._is_local_user(event.sender) + and event.membership == Membership.LEAVE + ): + await self._freeze_room_if_last_admin_is_leaving(event, state_events) + + return ret + + async def _freeze_room_if_last_admin_is_leaving( + self, event: EventBase, state_events: StateMap[EventBase] + ): + power_level_state_event = state_events.get( + (EventTypes.PowerLevels, "") + ) # type: EventBase + if not power_level_state_event: + return + power_level_content = power_level_state_event.content + + # Do some validation checks on the power level state event + if ( + not isinstance(power_level_content, dict) + or "users" not in power_level_content + or not isinstance(power_level_content["users"], dict) + ): + # We can't use this power level event to determine whether the room should be + # frozen. Bail out. + return + + user_id = event.get("sender") + if not user_id: + return + + # Get every admin user defined in the room's state + admin_users = { + user + for user, power_level in power_level_content["users"].items() + if power_level >= 100 + } + + if user_id not in admin_users: + # This user is not an admin, ignore them + return + + if any( + event_type == EventTypes.Member + and event.membership in [Membership.JOIN, Membership.INVITE] + and state_key in admin_users + and state_key != user_id + for (event_type, state_key), event in state_events.items() + ): + # There's another admin user in, or invited to, the room + return + + # Freeze the room by raising the required power level to send events to 100 + logger.info("Freezing room '%s'", event.room_id) + + # Modify the existing power levels to raise all required types to 100 + # + # This changes a power level state event's content from something like: + # { + # "redact": 50, + # "state_default": 50, + # "ban": 50, + # "notifications": { + # "room": 50 + # }, + # "events": { + # "m.room.avatar": 50, + # "m.room.encryption": 50, + # "m.room.canonical_alias": 50, + # "m.room.name": 50, + # "im.vector.modular.widgets": 50, + # "m.room.topic": 50, + # "m.room.tombstone": 50, + # "m.room.history_visibility": 100, + # "m.room.power_levels": 100 + # }, + # "users_default": 0, + # "events_default": 0, + # "users": { + # "@admin:example.com": 100, + # }, + # "kick": 50, + # "invite": 0 + # } + # + # to + # + # { + # "redact": 100, + # "state_default": 100, + # "ban": 100, + # "notifications": { + # "room": 50 + # }, + # "events": {} + # "users_default": 0, + # "events_default": 100, + # "users": { + # "@admin:example.com": 100, + # }, + # "kick": 100, + # "invite": 100 + # } + new_content = {} + for key, value in power_level_content.items(): + # Do not change "users_default", as that key specifies the default power + # level of new users + if isinstance(value, int) and key != "users_default": + value = 100 + new_content[key] = value + + # Set some values in case they are missing from the original + # power levels event content + new_content.update( + { + # Clear out any special-cased event keys + "events": {}, + # Ensure state_default and events_default keys exist and are 100. + # Otherwise a lower PL user could potentially send state events that + # aren't explicitly mentioned elsewhere in the power level dict + "state_default": 100, + "events_default": 100, + # Membership events default to 50 if they aren't present. Set them + # to 100 here, as they would be set to 100 if they were present anyways + "ban": 100, + "kick": 100, + "invite": 100, + "redact": 100, + } + ) + + await self.module_api.create_and_send_event_into_room( + { + "room_id": event.room_id, + "sender": user_id, + "type": EventTypes.PowerLevels, + "content": new_content, + "state_key": "", + } + ) + + def _on_membership_or_invite_restricted(self, event: EventBase) -> bool: + """Implements the checks and behaviour specified for the "restricted" rule. + + "restricted" currently means that users can only invite users if their server is + included in a limited list of domains. + + Args: + event: The event to check. + + Returns: + True if the event can be allowed, False otherwise. + """ + # We're not applying the rules on m.room.third_party_member events here because + # the filtering on threepids is done in check_threepid_can_be_invited, which is + # called before check_event_allowed. + if event.type == EventTypes.ThirdPartyInvite: + return True + + # We only need to process "join" and "invite" memberships, in order to be backward + # compatible, e.g. if a user from a blacklisted server joined a restricted room + # before the rules started being enforced on the server, that user must be able to + # leave it. + if event.membership not in [Membership.JOIN, Membership.INVITE]: + return True + + invitee_domain = get_domain_from_id(event.state_key) + return invitee_domain not in self.domains_forbidden_when_restricted + + def _on_membership_or_invite_unrestricted( + self, event: EventBase, state_events: StateMap[EventBase] + ) -> bool: + """Implements the checks and behaviour specified for the "unrestricted" rule. + + "unrestricted" currently means that forbidden users cannot join without an invite. + + Returns: + True if the event can be allowed, False otherwise. + """ + # If this is a join from a forbidden user and they don't have an invite to the + # room, then deny it + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + # Check if this user is from a forbidden server + target_domain = get_domain_from_id(event.state_key) + if target_domain in self.domains_forbidden_when_restricted: + # If so, they'll need an invite to join this room. Check if one exists + if not self._user_is_invited_to_room(event.state_key, state_events): + return False + + return True + + def _on_membership_or_invite_direct( + self, event: EventBase, state_events: StateMap[EventBase], + ) -> bool: + """Implements the checks and behaviour specified for the "direct" rule. + + "direct" currently means that no member is allowed apart from the two initial + members the room was created for (i.e. the room's creator and their first invitee). + + Args: + event: The event to check. + state_events: A dict mapping (event type, state key) to state event. + The state of the room before the event was sent. + + Returns: + True if the event can be allowed, False otherwise. + """ + # Get the room memberships and 3PID invite tokens from the room's state. + existing_members, threepid_tokens = self._get_members_and_tokens_from_state( + state_events + ) + + # There should never be more than one 3PID invite in the room state: if the second + # original user came and left, and we're inviting them using their email address, + # given we know they have a Matrix account binded to the address (so they could + # join the first time), Synapse will successfully look it up before attempting to + # store an invite on the IS. + if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite: + # If we already have a 3PID invite in flight, don't accept another one, unless + # the new one has the same invite token as its state key. This is because 3PID + # invite revocations must be allowed, and a revocation is basically a new 3PID + # invite event with an empty content and the same token as the invite it + # revokes. + return event.state_key in threepid_tokens + + if len(existing_members) == 2: + # If the user was within the two initial user of the room, Synapse would have + # looked it up successfully and thus sent a m.room.member here instead of + # m.room.third_party_invite. + if event.type == EventTypes.ThirdPartyInvite: + return False + + # We can only have m.room.member events here. The rule in this case is to only + # allow the event if its target is one of the initial two members in the room, + # i.e. the state key of one of the two m.room.member states in the room. + return event.state_key in existing_members + + # We're alone in the room (and always have been) and there's one 3PID invite in + # flight. + if len(existing_members) == 1 and len(threepid_tokens) == 1: + # We can only have m.room.member events here. In this case, we can only allow + # the event if it's either a m.room.member from the joined user (we can assume + # that the only m.room.member event is a join otherwise we wouldn't be able to + # send an event to the room) or an an invite event which target is the invited + # user. + target = event.state_key + is_from_threepid_invite = self._is_invite_from_threepid( + event, threepid_tokens[0] + ) + return is_from_threepid_invite or target == existing_members[0] + + return True + + def _is_power_level_content_allowed( + self, + content: Dict, + access_rule: str, + default_power_levels: Optional[Dict] = None, + on_room_creation: bool = True, + ) -> bool: + """Check if a given power levels event is permitted under the given access rule. + + It shouldn't be allowed if it either changes the default PL to a non-0 value or + gives a non-0 PL to a user that would have been forbidden from joining the room + under a more restrictive access rule. + + Args: + content: The content of the m.room.power_levels event to check. + access_rule: The access rule in place in this room. + default_power_levels: The default power levels when a room is created with + the specified access rule. Required if on_room_creation is True. + on_room_creation: True if this call is happening during a room's + creation, False otherwise. + + Returns: + Whether the content of the power levels event is valid. + """ + # Only enforce these rules during room creation + # + # We want to allow admins to modify or fix the power levels in a room if they + # have a special circumstance, but still want to encourage a certain pattern during + # room creation. + if on_room_creation: + # We specifically don't fail if "invite" or "state_default" are None, as those + # values should be replaced with our "default" power level values anyways, + # which are compliant + + invite = default_power_levels["invite"] + state_default = default_power_levels["state_default"] + + # If invite requirements are less than our required defaults + if content.get("invite", invite) < invite: + return False + + # If "other" state requirements are less than our required defaults + if content.get("state_default", state_default) < state_default: + return False + + # Check if we need to apply the restrictions with the current rule. + if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS: + return True + + # If users_default is explicitly set to a non-0 value, deny the event. + users_default = content.get("users_default", 0) + if users_default: + return False + + users = content.get("users", {}) + for user_id, power_level in users.items(): + server_name = get_domain_from_id(user_id) + # Check the domain against the blacklist. If found, and the PL isn't 0, deny + # the event. + if ( + server_name in self.domains_forbidden_when_restricted + and power_level != 0 + ): + return False + + return True + + def _on_join_rule_change(self, event: EventBase, rule: str) -> bool: + """Check whether a join rule change is allowed. + + A join rule change is always allowed unless the new join rule is "public" and + the current access rule is "direct". + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + Whether the change is allowed. + """ + if event.content.get("join_rule") == JoinRules.PUBLIC: + return rule != AccessRules.DIRECT + + return True + + def _on_room_avatar_change(self, event: EventBase, rule: str) -> bool: + """Check whether a change of room avatar is allowed. + The current rule is to forbid such a change in direct chats but allow it + everywhere else. + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + True if the event can be allowed, False otherwise. + """ + return rule != AccessRules.DIRECT + + def _on_room_name_change(self, event: EventBase, rule: str) -> bool: + """Check whether a change of room name is allowed. + The current rule is to forbid such a change in direct chats but allow it + everywhere else. + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + True if the event can be allowed, False otherwise. + """ + return rule != AccessRules.DIRECT + + def _on_room_topic_change(self, event: EventBase, rule: str) -> bool: + """Check whether a change of room topic is allowed. + The current rule is to forbid such a change in direct chats but allow it + everywhere else. + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + True if the event can be allowed, False otherwise. + """ + return rule != AccessRules.DIRECT + + @staticmethod + def _get_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]: + """Extract the rule to be applied from the given set of state events. + + Args: + state_events: A dict mapping (event type, state key) to state event. + + Returns: + The name of the rule (either "direct", "restricted" or "unrestricted") if found, + else None. + """ + access_rules = state_events.get((ACCESS_RULES_TYPE, "")) + if access_rules is None: + return AccessRules.RESTRICTED + + return access_rules.content.get("rule") + + @staticmethod + def _get_join_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]: + """Extract the room's join rule from the given set of state events. + + Args: + state_events (dict[tuple[event type, state key], EventBase]): The set of state + events. + + Returns: + The name of the join rule (either "public", or "invite") if found, else None. + """ + join_rule_event = state_events.get((EventTypes.JoinRules, "")) + if join_rule_event is None: + return None + + return join_rule_event.content.get("join_rule") + + @staticmethod + def _get_members_and_tokens_from_state( + state_events: StateMap[EventBase], + ) -> Tuple[List[str], List[str]]: + """Retrieves the list of users that have a m.room.member event in the room, + as well as 3PID invites tokens in the room. + + Args: + state_events: A dict mapping (event type, state key) to state event. + + Returns: + A tuple containing the: + * targets of the m.room.member events in the state. + * 3PID invite tokens in the state. + """ + existing_members = [] + threepid_invite_tokens = [] + for key, state_event in state_events.items(): + if key[0] == EventTypes.Member and state_event.content: + existing_members.append(state_event.state_key) + if key[0] == EventTypes.ThirdPartyInvite and state_event.content: + # Don't include revoked invites. + threepid_invite_tokens.append(state_event.state_key) + + return existing_members, threepid_invite_tokens + + @staticmethod + def _is_invite_from_threepid(invite: EventBase, threepid_invite_token: str) -> bool: + """Checks whether the given invite follows the given 3PID invite. + + Args: + invite: The m.room.member event with "invite" membership. + threepid_invite_token: The state key from the 3PID invite. + + Returns: + Whether the invite is due to the given 3PID invite. + """ + token = ( + invite.content.get("third_party_invite", {}) + .get("signed", {}) + .get("token", "") + ) + + return token == threepid_invite_token + + def _is_local_user(self, user_id: str) -> bool: + """Checks whether a given user ID belongs to this homeserver, or a remote + + Args: + user_id: A user ID to check. + + Returns: + True if the user belongs to this homeserver, False otherwise. + """ + user = UserID.from_string(user_id) + + # Extract the localpart and ask the module API for a user ID from the localpart + # The module API will append the local homeserver's server_name + local_user_id = self.module_api.get_qualified_user_id(user.localpart) + + # If the user ID we get based on the localpart is the same as the original user ID, + # then they were a local user + return user_id == local_user_id + + def _user_is_invited_to_room( + self, user_id: str, state_events: StateMap[EventBase] + ) -> bool: + """Checks whether a given user has been invited to a room + + A user has an invite for a room if its state contains a `m.room.member` + event with membership "invite" and their user ID as the state key. + + Args: + user_id: The user to check. + state_events: The state events from the room. + + Returns: + True if the user has been invited to the room, or False if they haven't. + """ + for (event_type, state_key), state_event in state_events.items(): + if ( + event_type == EventTypes.Member + and state_key == user_id + and state_event.membership == Membership.INVITE + ): + return True + + return False diff --git a/synapse/types.py b/synapse/types.py
index eafe729dfe..9629b26b01 100644 --- a/synapse/types.py +++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import ( import attr from signedjson.key import decode_verify_key_bytes +from six.moves import filter from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError @@ -335,6 +336,19 @@ def contains_invalid_mxid_characters(localpart: str) -> bool: return any(c not in mxid_localpart_allowed_characters for c in localpart) +def strip_invalid_mxid_characters(localpart): + """Removes any invalid characters from an mxid + + Args: + localpart (basestring): the localpart to be stripped + + Returns: + localpart (basestring): the localpart having been stripped + """ + filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart) + return "".join(filtered) + + UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") # the following is a pattern which matches '=', and bytes which are not allowed in a mxid diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 1ee61851e4..09b094ded7 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py
@@ -49,7 +49,8 @@ def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: module = importlib.import_module(module) provider_class = getattr(module, clz) - module_config = provider.get("config") + # Load the module config. If None, pass an empty dictionary instead + module_config = provider.get("config") or {} try: provider_config = provider_class.parse_config(module_config) except jsonschema.ValidationError as e: diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 43c2e0ac23..63f955acff 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py
@@ -19,8 +19,8 @@ import re logger = logging.getLogger(__name__) -def check_3pid_allowed(hs, medium, address): - """Checks whether a given format of 3PID is allowed to be used on this HS +async def check_3pid_allowed(hs, medium, address): + """Checks whether a given 3PID is allowed to be used on this HS Args: hs (synapse.server.HomeServer): server @@ -31,6 +31,33 @@ def check_3pid_allowed(hs, medium, address): bool: whether the 3PID medium/address is allowed to be added to this HS """ + if hs.config.check_is_for_allowed_local_3pids: + data = await hs.get_simple_http_client().get_json( + "https://%s%s" + % ( + hs.config.check_is_for_allowed_local_3pids, + "/_matrix/identity/api/v1/internal-info", + ), + {"medium": medium, "address": address}, + ) + + # Check for invalid response + if "hs" not in data and "shadow_hs" not in data: + return False + + # Check if this user is intended to register for this homeserver + if ( + data.get("hs") != hs.config.server_name + and data.get("shadow_hs") != hs.config.server_name + ): + return False + + if data.get("requires_invite", False) and not data.get("invited", False): + # Requires an invite but hasn't been invited + return False + + return True + if hs.config.allowed_local_3pids: for constraint in hs.config.allowed_local_3pids: logger.debug(