Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/auth.py | 33
-rw-r--r--  synapse/api/constants.py | 8
-rw-r--r--  synapse/api/errors.py | 24
-rw-r--r--  synapse/api/presence.py | 3
-rw-r--r--  synapse/api/room_versions.py | 27
-rw-r--r--  synapse/app/_base.py | 14
-rw-r--r--  synapse/app/generic_worker.py | 5
-rw-r--r--  synapse/appservice/__init__.py | 5
-rw-r--r--  synapse/appservice/api.py | 18
-rw-r--r--  synapse/appservice/scheduler.py | 2
-rw-r--r--  synapse/config/_base.py | 15
-rw-r--r--  synapse/config/_base.pyi | 2
-rw-r--r--  synapse/config/account_validity.py | 149
-rw-r--r--  synapse/config/api.py | 23
-rw-r--r--  synapse/config/auth.py | 5
-rw-r--r--  synapse/config/cas.py | 32
-rw-r--r--  synapse/config/database.py | 3
-rw-r--r--  synapse/config/emailconfig.py | 8
-rw-r--r--  synapse/config/experimental.py | 7
-rw-r--r--  synapse/config/homeserver.py | 3
-rw-r--r--  synapse/config/logger.py | 5
-rw-r--r--  synapse/config/oidc_config.py | 23
-rw-r--r--  synapse/config/ratelimiting.py | 9
-rw-r--r--  synapse/config/registration.py | 259
-rw-r--r--  synapse/config/repository.py | 51
-rw-r--r--  synapse/config/room_directory.py | 2
-rw-r--r--  synapse/config/saml2_config.py | 25
-rw-r--r--  synapse/config/server.py | 188
-rw-r--r--  synapse/config/sso.py | 22
-rw-r--r--  synapse/config/user_directory.py | 18
-rw-r--r--  synapse/config/workers.py | 19
-rw-r--r--  synapse/event_auth.py | 34
-rw-r--r--  synapse/events/builder.py | 4
-rw-r--r--  synapse/events/snapshot.py | 3
-rw-r--r--  synapse/events/spamcheck.py | 117
-rw-r--r--  synapse/events/third_party_rules.py | 3
-rw-r--r--  synapse/events/utils.py | 21
-rw-r--r--  synapse/federation/federation_client.py | 66
-rw-r--r--  synapse/federation/federation_server.py | 92
-rw-r--r--  synapse/federation/persistence.py | 6
-rw-r--r--  synapse/federation/send_queue.py | 8
-rw-r--r--  synapse/federation/sender/__init__.py | 10
-rw-r--r--  synapse/federation/sender/per_destination_queue.py | 12
-rw-r--r--  synapse/federation/sender/transaction_manager.py | 5
-rw-r--r--  synapse/federation/transport/client.py | 155
-rw-r--r--  synapse/federation/transport/server.py | 193
-rw-r--r--  synapse/federation/units.py | 8
-rw-r--r--  synapse/groups/attestations.py | 49
-rw-r--r--  synapse/groups/groups_server.py | 312
-rw-r--r--  synapse/handlers/account_validity.py | 127
-rw-r--r--  synapse/handlers/admin.py | 6
-rw-r--r--  synapse/handlers/appservice.py | 4
-rw-r--r--  synapse/handlers/auth.py | 50
-rw-r--r--  synapse/handlers/cas_handler.py | 56
-rw-r--r--  synapse/handlers/deactivate_account.py | 11
-rw-r--r--  synapse/handlers/device.py | 30
-rw-r--r--  synapse/handlers/devicemessage.py | 7
-rw-r--r--  synapse/handlers/e2e_keys.py | 24
-rw-r--r--  synapse/handlers/events.py | 3
-rw-r--r--  synapse/handlers/federation.py | 321
-rw-r--r--  synapse/handlers/groups_local.py | 24
-rw-r--r--  synapse/handlers/identity.py | 277
-rw-r--r--  synapse/handlers/initial_sync.py | 12
-rw-r--r--  synapse/handlers/message.py | 51
-rw-r--r--  synapse/handlers/oidc_handler.py | 233
-rw-r--r--  synapse/handlers/pagination.py | 14
-rw-r--r--  synapse/handlers/presence.py | 33
-rw-r--r--  synapse/handlers/profile.py | 216
-rw-r--r--  synapse/handlers/receipts.py | 9
-rw-r--r--  synapse/handlers/register.py | 152
-rw-r--r--  synapse/handlers/room.py | 90
-rw-r--r--  synapse/handlers/room_list.py | 1
-rw-r--r--  synapse/handlers/room_member.py | 260
-rw-r--r--  synapse/handlers/room_member_worker.py | 56
-rw-r--r--  synapse/handlers/saml_handler.py | 32
-rw-r--r--  synapse/handlers/set_password.py | 3
-rw-r--r--  synapse/handlers/sso.py | 112
-rw-r--r--  synapse/handlers/stats.py | 9
-rw-r--r--  synapse/handlers/sync.py | 120
-rw-r--r--  synapse/handlers/typing.py | 9
-rw-r--r--  synapse/handlers/user_directory.py | 13
-rw-r--r--  synapse/http/__init__.py | 3
-rw-r--r--  synapse/http/client.py | 43
-rw-r--r--  synapse/http/connectproxyclient.py | 96
-rw-r--r--  synapse/http/federation/matrix_federation_agent.py | 13
-rw-r--r--  synapse/http/federation/well_known_resolver.py | 3
-rw-r--r--  synapse/http/matrixfederationclient.py | 20
-rw-r--r--  synapse/http/proxyagent.py | 117
-rw-r--r--  synapse/http/request_metrics.py | 3
-rw-r--r--  synapse/http/server.py | 72
-rw-r--r--  synapse/http/servlet.py | 69
-rw-r--r--  synapse/http/site.py | 9
-rw-r--r--  synapse/logging/_remote.py | 4
-rw-r--r--  synapse/logging/_structured.py | 9
-rw-r--r--  synapse/logging/context.py | 12
-rw-r--r--  synapse/logging/opentracing.py | 8
-rw-r--r--  synapse/logging/utils.py | 3
-rw-r--r--  synapse/metrics/__init__.py | 7
-rw-r--r--  synapse/metrics/_exposition.py | 2
-rw-r--r--  synapse/metrics/background_process_metrics.py | 9
-rw-r--r--  synapse/module_api/__init__.py | 9
-rw-r--r--  synapse/notifier.py | 30
-rw-r--r--  synapse/push/baserules.py | 12
-rw-r--r--  synapse/push/bulk_push_rule_evaluator.py | 9
-rw-r--r--  synapse/push/emailpusher.py | 26
-rw-r--r--  synapse/push/httppusher.py | 23
-rw-r--r--  synapse/push/mailer.py | 213
-rw-r--r--  synapse/push/pusherpool.py | 15
-rw-r--r--  synapse/python_dependencies.py | 10
-rw-r--r--  synapse/replication/http/_base.py | 5
-rw-r--r--  synapse/replication/http/account_data.py | 6
-rw-r--r--  synapse/replication/http/membership.py | 136
-rw-r--r--  synapse/replication/http/register.py | 6
-rw-r--r--  synapse/replication/tcp/commands.py | 3
-rw-r--r--  synapse/replication/tcp/external_cache.py | 10
-rw-r--r--  synapse/replication/tcp/handler.py | 30
-rw-r--r--  synapse/replication/tcp/protocol.py | 27
-rw-r--r--  synapse/replication/tcp/redis.py | 32
-rw-r--r--  synapse/replication/tcp/resource.py | 6
-rw-r--r--  synapse/replication/tcp/streams/_base.py | 26
-rw-r--r--  synapse/replication/tcp/streams/events.py | 3
-rw-r--r--  synapse/res/templates/account_previously_renewed.html | 1
-rw-r--r--  synapse/res/templates/account_renewed.html | 2
-rw-r--r--  synapse/res/templates/sso.css | 53
-rw-r--r--  synapse/res/templates/sso_account_deactivated.html | 1
-rw-r--r--  synapse/res/templates/sso_auth_account_details.html | 61
-rw-r--r--  synapse/res/templates/sso_auth_bad_user.html | 3
-rw-r--r--  synapse/res/templates/sso_auth_confirm.html | 3
-rw-r--r--  synapse/res/templates/sso_auth_success.html | 1
-rw-r--r--  synapse/res/templates/sso_error.html | 1
-rw-r--r--  synapse/res/templates/sso_footer.html | 19
-rw-r--r--  synapse/res/templates/sso_login_idp_picker.html | 74
-rw-r--r--  synapse/res/templates/sso_new_user_consent.html | 15
-rw-r--r--  synapse/res/templates/sso_partial_profile.html | 19
-rw-r--r--  synapse/res/templates/sso_redirect_confirm.html | 42
-rw-r--r--  synapse/rest/__init__.py | 3
-rw-r--r--  synapse/rest/admin/__init__.py | 2
-rw-r--r--  synapse/rest/admin/groups.py | 3
-rw-r--r--  synapse/rest/admin/media.py | 9
-rw-r--r--  synapse/rest/admin/rooms.py | 83
-rw-r--r--  synapse/rest/admin/users.py | 4
-rw-r--r--  synapse/rest/client/v1/login.py | 8
-rw-r--r--  synapse/rest/client/v1/presence.py | 4
-rw-r--r--  synapse/rest/client/v1/profile.py | 45
-rw-r--r--  synapse/rest/client/v1/pusher.py | 4
-rw-r--r--  synapse/rest/client/v1/room.py | 23
-rw-r--r--  synapse/rest/client/v2_alpha/account.py | 231
-rw-r--r--  synapse/rest/client/v2_alpha/account_data.py | 14
-rw-r--r--  synapse/rest/client/v2_alpha/account_validity.py | 30
-rw-r--r--  synapse/rest/client/v2_alpha/devices.py | 22
-rw-r--r--  synapse/rest/client/v2_alpha/groups.py | 346
-rw-r--r--  synapse/rest/client/v2_alpha/keys.py | 5
-rw-r--r--  synapse/rest/client/v2_alpha/knock.py | 103
-rw-r--r--  synapse/rest/client/v2_alpha/register.py | 259
-rw-r--r--  synapse/rest/client/v2_alpha/relations.py | 8
-rw-r--r--  synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py | 2
-rw-r--r--  synapse/rest/client/v2_alpha/sync.py | 81
-rw-r--r--  synapse/rest/client/v2_alpha/user_directory.py | 142
-rw-r--r--  synapse/rest/client/versions.py | 5
-rw-r--r--  synapse/rest/media/v1/_base.py | 2
-rw-r--r--  synapse/rest/media/v1/download_resource.py | 3
-rw-r--r--  synapse/rest/media/v1/media_repository.py | 34
-rw-r--r--  synapse/rest/media/v1/media_storage.py | 61
-rw-r--r--  synapse/rest/media/v1/preview_url_resource.py | 77
-rw-r--r--  synapse/rest/media/v1/upload_resource.py | 12
-rw-r--r--  synapse/rest/synapse/client/oidc/callback_resource.py | 13
-rw-r--r--  synapse/rulecheck/__init__.py | 0
-rw-r--r--  synapse/rulecheck/domain_rule_checker.py | 181
-rw-r--r--  synapse/server.py | 26
-rw-r--r--  synapse/server_notices/resource_limits_server_notices.py | 2
-rw-r--r--  synapse/state/__init__.py | 16
-rw-r--r--  synapse/state/v1.py | 14
-rw-r--r--  synapse/state/v2.py | 6
-rw-r--r--  synapse/storage/__init__.py | 3
-rw-r--r--  synapse/storage/background_updates.py | 8
-rw-r--r--  synapse/storage/database.py | 41
-rw-r--r--  synapse/storage/databases/__init__.py | 5
-rw-r--r--  synapse/storage/databases/main/__init__.py | 2
-rw-r--r--  synapse/storage/databases/main/appservice.py | 3
-rw-r--r--  synapse/storage/databases/main/client_ips.py | 12
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 2
-rw-r--r--  synapse/storage/databases/main/devices.py | 38
-rw-r--r--  synapse/storage/databases/main/directory.py | 7
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 11
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 19
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 18
-rw-r--r--  synapse/storage/databases/main/events.py | 45
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 29
-rw-r--r--  synapse/storage/databases/main/events_forward_extremities.py | 7
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 18
-rw-r--r--  synapse/storage/databases/main/group_server.py | 40
-rw-r--r--  synapse/storage/databases/main/keys.py | 7
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 15
-rw-r--r--  synapse/storage/databases/main/metrics.py | 4
-rw-r--r--  synapse/storage/databases/main/presence.py | 4
-rw-r--r--  synapse/storage/databases/main/profile.py | 157
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 8
-rw-r--r--  synapse/storage/databases/main/pusher.py | 7
-rw-r--r--  synapse/storage/databases/main/receipts.py | 13
-rw-r--r--  synapse/storage/databases/main/registration.py | 145
-rw-r--r--  synapse/storage/databases/main/room.py | 62
-rw-r--r--  synapse/storage/databases/main/roommember.py | 17
-rw-r--r--  synapse/storage/databases/main/schema/delta/33/remote_media_ts.py | 3
-rw-r--r--  synapse/storage/databases/main/schema/delta/48/profiles_batch.sql | 36
-rw-r--r--  synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql | 23
-rw-r--r--  synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql | 16
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql | 18
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql | 17
-rw-r--r--  synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres | 15
-rw-r--r--  synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite | 14
-rw-r--r--  synapse/storage/databases/main/state.py | 11
-rw-r--r--  synapse/storage/databases/main/state_deltas.py | 4
-rw-r--r--  synapse/storage/databases/main/stats.py | 5
-rw-r--r--  synapse/storage/databases/main/stream.py | 42
-rw-r--r--  synapse/storage/databases/main/transactions.py | 21
-rw-r--r--  synapse/storage/databases/main/ui_auth.py | 22
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 73
-rw-r--r--  synapse/storage/databases/state/bg_updates.py | 2
-rw-r--r--  synapse/storage/databases/state/store.py | 6
-rw-r--r--  synapse/storage/engines/__init__.py | 8
-rw-r--r--  synapse/storage/engines/_base.py | 6
-rw-r--r--  synapse/storage/engines/postgres.py | 3
-rw-r--r--  synapse/storage/engines/sqlite.py | 14
-rw-r--r--  synapse/storage/persist_events.py | 12
-rw-r--r--  synapse/storage/prepare_database.py | 14
-rw-r--r--  synapse/storage/purge_events.py | 6
-rw-r--r--  synapse/storage/state.py | 5
-rw-r--r--  synapse/storage/types.py | 37
-rw-r--r--  synapse/storage/util/id_generators.py | 24
-rw-r--r--  synapse/storage/util/sequence.py | 11
-rw-r--r--  synapse/third_party_rules/__init__.py | 14
-rw-r--r--  synapse/third_party_rules/access_rules.py | 971
-rw-r--r--  synapse/types.py | 25
-rw-r--r--  synapse/util/async_helpers.py | 15
-rw-r--r--  synapse/util/caches/__init__.py | 6
-rw-r--r--  synapse/util/caches/cached_call.py | 129
-rw-r--r--  synapse/util/caches/descriptors.py | 17
-rw-r--r--  synapse/util/caches/stream_change_cache.py | 6
-rw-r--r--  synapse/util/distributor.py | 5
-rw-r--r--  synapse/util/file_consumer.py | 15
-rw-r--r--  synapse/util/iterutils.py | 3
-rw-r--r--  synapse/util/jsonobject.py | 6
-rw-r--r--  synapse/util/metrics.py | 3
-rw-r--r--  synapse/util/module_loader.py | 2
-rw-r--r--  synapse/util/patch_inline_callbacks.py | 17
-rw-r--r--  synapse/util/stringutils.py | 33
-rw-r--r--  synapse/util/threepids.py | 31
-rw-r--r--  synapse/visibility.py | 3
248 files changed, 8126 insertions, 2177 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py

index 67ecbd32ff..dbf3799d2e 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -79,7 +79,7 @@ class Auth:
         self._auth_blocking = AuthBlocking(self.hs)
 
-        self._account_validity = hs.config.account_validity
+        self._account_validity_enabled = hs.config.account_validity_enabled
         self._track_appservice_user_ips = hs.config.track_appservice_user_ips
         self._macaroon_secret_key = hs.config.macaroon_secret_key
@@ -168,7 +168,7 @@ class Auth:
         rights: str = "access",
         allow_expired: bool = False,
     ) -> synapse.types.Requester:
-        """ Get a registered user's ID.
+        """Get a registered user's ID.
 
         Args:
             request: An HTTP request with an access_token query parameter.
@@ -192,7 +192,7 @@ class Auth:
             access_token = self.get_access_token_from_request(request)
 
-            user_id, app_service = await self._get_appservice_user_id(request)
+            user_id, app_service = self._get_appservice_user_id(request)
             if user_id:
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
@@ -222,7 +222,7 @@ class Auth:
             shadow_banned = user_info.shadow_banned
 
             # Deny the request if the user account has expired.
-            if self._account_validity.enabled and not allow_expired:
+            if self._account_validity_enabled and not allow_expired:
                 if await self.store.is_account_expired(
                     user_info.user_id, self.clock.time_msec()
                 ):
@@ -268,10 +268,11 @@ class Auth:
         except KeyError:
             raise MissingClientTokenError()
 
-    async def _get_appservice_user_id(self, request):
+    def _get_appservice_user_id(self, request):
         app_service = self.store.get_app_service_by_token(
             self.get_access_token_from_request(request)
         )
+
         if app_service is None:
             return None, None
@@ -289,14 +290,21 @@ class Auth:
         if not app_service.is_interested_in_user(user_id):
             raise AuthError(403, "Application service cannot masquerade as this user.")
 
-        if not (await self.store.get_user_by_id(user_id)):
-            raise AuthError(403, "Application service has not registered this user")
+        # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+        # if not (yield self.store.get_user_by_id(user_id)):
+        #     raise AuthError(
+        #         403,
+        #         "Application service has not registered this user"
+        #     )
 
         return user_id, app_service
 
     async def get_user_by_access_token(
-        self, token: str, rights: str = "access", allow_expired: bool = False,
+        self,
+        token: str,
+        rights: str = "access",
+        allow_expired: bool = False,
     ) -> TokenLookupResult:
-        """ Validate access token and get user_id from it
+        """Validate access token and get user_id from it
 
         Args:
             token: The access token to get the user by
@@ -489,7 +497,7 @@ class Auth:
         return service
 
     async def is_server_admin(self, user: UserID) -> bool:
-        """ Check if the given user is a local server admin.
+        """Check if the given user is a local server admin.
 
         Args:
             user: user to check
@@ -500,7 +508,10 @@ class Auth:
         return await self.store.is_server_admin(user)
 
     def compute_auth_events(
-        self, event, current_state_ids: StateMap[str], for_verification: bool = False,
+        self,
+        event,
+        current_state_ids: StateMap[str],
+        for_verification: bool = False,
    ) -> List[str]:
         """Given an event and current state return the list of event IDs used
         to auth an event.
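For context, the expiry gate above boils down to a timestamp comparison. A minimal sketch, assuming a precomputed expiration timestamp in milliseconds (the real is_account_expired is a storage method that queries the database; the names below are illustrative only):

    def is_account_expired(expiration_ts_ms, now_ms):
        # Expired if an expiration timestamp is set and lies in the past.
        return expiration_ts_ms is not None and expiration_ts_ms <= now_ms
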
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 565a8cd76a..af8d59cf87 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -27,6 +27,11 @@ MAX_ALIAS_LENGTH = 255
 # the maximum length for a user id is 255 characters
 MAX_USERID_LENGTH = 255
 
+# The maximum length for a group id is 255 characters
+MAX_GROUPID_LENGTH = 255
+MAX_GROUP_CATEGORYID_LENGTH = 255
+MAX_GROUP_ROLEID_LENGTH = 255
+
 
 class Membership:
@@ -128,8 +133,7 @@ class UserTypes:
 
 
 class RelationTypes:
-    """The types of relations known to this server.
-    """
+    """The types of relations known to this server."""
 
     ANNOTATION = "m.annotation"
     REPLACE = "m.replace"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cd6670d0a2..a71e518f90 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -390,8 +391,7 @@ class InvalidCaptchaError(SynapseError):
 
 
 class LimitExceededError(SynapseError):
-    """A client has sent too many requests and is being throttled.
-    """
+    """A client has sent too many requests and is being throttled."""
 
     def __init__(
         self,
@@ -408,8 +408,7 @@ class LimitExceededError(SynapseError):
 
 
 class RoomKeysVersionError(SynapseError):
-    """A client has tried to upload to a non-current version of the room_keys store
-    """
+    """A client has tried to upload to a non-current version of the room_keys store"""
 
     def __init__(self, current_version: str):
         """
@@ -426,7 +425,9 @@ class UnsupportedRoomVersionError(SynapseError):
 
     def __init__(self, msg: str = "Homeserver does not support this room version"):
         super().__init__(
-            code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
+            code=400,
+            msg=msg,
+            errcode=Codes.UNSUPPORTED_ROOM_VERSION,
         )
 
 
@@ -461,8 +462,7 @@ class IncompatibleRoomVersionError(SynapseError):
 
 
 class PasswordRefusedError(SynapseError):
-    """A password has been refused, either during password reset/change or registration.
-    """
+    """A password has been refused, either during password reset/change or registration."""
 
     def __init__(
         self,
@@ -470,7 +470,9 @@ class PasswordRefusedError(SynapseError):
         errcode: str = Codes.WEAK_PASSWORD,
     ):
         super().__init__(
-            code=400, msg=msg, errcode=errcode,
+            code=400,
+            msg=msg,
+            errcode=errcode,
         )
 
 
@@ -493,7 +495,7 @@ class RequestSendFailed(RuntimeError):
 
 
 def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
-    """ Utility method for constructing an error response for client-server
+    """Utility method for constructing an error response for client-server
     interactions.
 
     Args:
@@ -510,7 +512,7 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
 
 
 class FederationError(RuntimeError):
-    """ This class is used to inform remote homeservers about erroneous
+    """This class is used to inform remote homeservers about erroneous
     PDUs they sent us.
 
     FATAL: The remote server could not interpret the source event.
diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index 18a462f0ee..b9a8e29460 100644
--- a/synapse/api/presence.py
+++ b/synapse/api/presence.py
@@ -56,8 +56,7 @@ class UserPresenceState(
 
     @classmethod
     def default(cls, user_id):
-        """Returns a default presence state.
-        """
+        """Returns a default presence state."""
         return cls(
             user_id=user_id,
             state=PresenceState.OFFLINE,
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index de2cc15d33..139fbf5524 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -57,7 +57,7 @@ class RoomVersion:
     state_res = attr.ib(type=int)  # one of the StateResolutionVersions
     enforce_key_validity = attr.ib(type=bool)
 
-    # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
+    # Before MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
     # Strictly enforce canonicaljson, do not allow:
     # * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1]
@@ -69,6 +69,11 @@ class RoomVersion:
     limit_notifications_power_levels = attr.ib(type=bool)
     # MSC2174/MSC2176: Apply updated redaction rules algorithm.
     msc2176_redaction_rules = attr.ib(type=bool)
+    # MSC2174/MSC2176: Apply updated redaction rules algorithm.
+    msc2176_redaction_rules = attr.ib(type=bool)
+    # MSC2403: Allows join_rules to be set to 'knock', changes auth rules to allow sending
+    # m.room.membership event with membership 'knock'.
+    allow_knocking = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -82,6 +87,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        allow_knocking=False,
     )
     V2 = RoomVersion(
         "2",
@@ -93,6 +99,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        allow_knocking=False,
     )
     V3 = RoomVersion(
         "3",
@@ -104,6 +111,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        allow_knocking=False,
     )
     V4 = RoomVersion(
         "4",
@@ -115,6 +123,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        allow_knocking=False,
     )
     V5 = RoomVersion(
         "5",
@@ -126,6 +135,7 @@ class RoomVersions:
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
         msc2176_redaction_rules=False,
+        allow_knocking=False,
     )
     V6 = RoomVersion(
         "6",
@@ -137,6 +147,19 @@ class RoomVersions:
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=False,
+        allow_knocking=False,
+    )
+    V7 = RoomVersion(
+        "7",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+        allow_knocking=True,
     )
     MSC2176 = RoomVersion(
         "org.matrix.msc2176",
@@ -148,6 +171,7 @@ class RoomVersions:
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
         msc2176_redaction_rules=True,
+        allow_knocking=False,
     )
 
 
@@ -160,6 +184,7 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V4,
         RoomVersions.V5,
         RoomVersions.V6,
+        RoomVersions.V7,
         RoomVersions.MSC2176,
     )
 }  # type: Dict[str, RoomVersion]
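The new allow_knocking flag lets calling code gate knocks per room version. A hedged sketch of how event-auth code might consult it (illustrative only, not Synapse's actual auth implementation):

    def check_knock_allowed(room_version) -> None:
        # Reject m.room.member events with membership "knock" in room
        # versions whose capability flag (added above) is False.
        if not room_version.allow_knocking:
            raise ValueError(
                "Room version %s does not allow knocking" % room_version.identifier
            )
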
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 9840a9d55b..43b1f1e94b 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -58,7 +58,7 @@ def register_sighup(func, *args, **kwargs):
 
 
 def start_worker_reactor(appname, config, run_command=reactor.run):
-    """ Run the reactor in the main process
+    """Run the reactor in the main process
 
     Daemonizes if necessary, and then configures some resources, before starting
     the reactor. Pulls configuration from the 'worker' settings in 'config'.
@@ -93,7 +93,7 @@ def start_reactor(
     logger,
     run_command=reactor.run,
 ):
-    """ Run the reactor in the main process
+    """Run the reactor in the main process
 
     Daemonizes if necessary, and then configures some resources, before starting
     the reactor
@@ -313,9 +313,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
     refresh_certificate(hs)
 
     # Start the tracer
-    synapse.logging.opentracing.init_tracer(  # type: ignore[attr-defined] # noqa
-        hs
-    )
+    synapse.logging.opentracing.init_tracer(hs)  # type: ignore[attr-defined] # noqa
 
     # It is now safe to start your Synapse.
     hs.start_listening(listeners)
@@ -370,8 +368,7 @@ def setup_sentry(hs):
 
 
 def setup_sdnotify(hs):
-    """Adds process state hooks to tell systemd what we are up to.
-    """
+    """Adds process state hooks to tell systemd what we are up to."""
 
     # Tell systemd our state, if we're using it. This will silently fail if
     # we're not using systemd.
@@ -405,8 +402,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
 
 
 class _LimitedHostnameResolver:
-    """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
-    """
+    """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
 
     def __init__(self, resolver, max_dns_requests_in_flight):
         self._resolver = resolver
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 516f2464b4..fe0178dd79 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -165,7 +165,7 @@ class PresenceStatusStubServlet(RestServlet):
 
     async def on_GET(self, request, user_id):
         await self.auth.get_user_by_req(request)
-        return 200, {"presence": "offline"}
+        return 200, {"presence": "offline", "user_id": user_id}
 
     async def on_PUT(self, request, user_id):
         await self.auth.get_user_by_req(request)
@@ -421,8 +421,7 @@ class GenericWorkerPresence(BasePresenceHandler):
         ]
 
     async def set_state(self, target_user, state, ignore_status_msg=False):
-        """Set the presence state of the user.
-        """
+        """Set the presence state of the user."""
         presence = state["presence"]
 
         valid_presence = (
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 3944780a42..0bfc5e445f 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -166,7 +166,10 @@ class ApplicationService:
 
     @cached(num_args=1, cache_context=True)
     async def matches_user_in_member_list(
-        self, room_id: str, store: "DataStore", cache_context: _CacheContext,
+        self,
+        room_id: str,
+        store: "DataStore",
+        cache_context: _CacheContext,
     ) -> bool:
         """Check if this service is interested a room based upon it's membership
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index e366a982b8..09104488f8 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, ThirdPartyEntityKind
+from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.events import EventBase
 from synapse.events.utils import serialize_event
@@ -76,9 +76,6 @@ def _is_valid_3pe_result(r, field):
     fields = r["fields"]
     if not isinstance(fields, dict):
         return False
-    for k in fields.keys():
-        if not isinstance(fields[k], str):
-            return False
 
     return True
 
@@ -230,7 +227,9 @@ class ApplicationServiceApi(SimpleHttpClient):
 
         try:
             await self.put_json(
-                uri=uri, json_body=body, args={"access_token": service.hs_token},
+                uri=uri,
+                json_body=body,
+                args={"access_token": service.hs_token},
             )
             sent_transactions_counter.labels(service.id).inc()
             sent_events_counter.labels(service.id).inc(len(events))
@@ -249,9 +248,14 @@ class ApplicationServiceApi(SimpleHttpClient):
                 e,
                 time_now,
                 as_client_event=True,
-                is_invite=(
+                # If this is an invite or a knock membership event, and we're interested
+                # in this user, then include any stripped state alongside the event.
+                include_stripped_room_state=(
                     e.type == EventTypes.Member
-                    and e.membership == "invite"
+                    and (
+                        e.membership == Membership.INVITE
+                        or e.membership == Membership.KNOCK
+                    )
                     and service.is_interested_in_user(e.state_key)
                 ),
             )
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 58291afc22..366c476f80 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -68,7 +68,7 @@ MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
 
 
 class ApplicationServiceScheduler:
-    """ Public facing API for this module. Does the required DI to tie the
+    """Public facing API for this module. Does the required DI to tie the
     components together. This also serves as the "event_pool", which in this
     case is a simple array.
     """
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index a851f8801d..40af1979c4 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,6 +20,7 @@ import errno
 import os
 from collections import OrderedDict
 from hashlib import sha256
+from io import open as io_open
 from textwrap import dedent
 from typing import Any, Iterable, List, MutableMapping, Optional
 
@@ -200,7 +201,7 @@ class Config:
     @classmethod
     def read_file(cls, file_path, config_name):
         cls.check_file(file_path, config_name)
-        with open(file_path) as file_stream:
+        with io_open(file_path, encoding="utf-8") as file_stream:
             return file_stream.read()
 
     def read_template(self, filename: str) -> jinja2.Template:
@@ -224,7 +225,9 @@ class Config:
         return self.read_templates([filename])[0]
 
     def read_templates(
-        self, filenames: List[str], custom_template_directory: Optional[str] = None,
+        self,
+        filenames: List[str],
+        custom_template_directory: Optional[str] = None,
     ) -> List[jinja2.Template]:
         """Load a list of template files from disk using the given variables.
@@ -264,7 +267,10 @@ class Config:
         # TODO: switch to synapse.util.templates.build_jinja_env
         loader = jinja2.FileSystemLoader(search_directories)
-        env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
+        env = jinja2.Environment(
+            loader=loader,
+            autoescape=jinja2.select_autoescape(),
+        )
 
         # Update the environment with our custom filters
         env.filters.update(
@@ -825,8 +831,7 @@ class ShardedWorkerHandlingConfig:
     instances = attr.ib(type=List[str])
 
     def should_handle(self, instance_name: str, key: str) -> bool:
-        """Whether this instance is responsible for handling the given key.
-        """
+        """Whether this instance is responsible for handling the given key."""
 
         # If multiple instances are not defined we always return true
         if not self.instances or len(self.instances) == 1:
             return True
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 70025b5d60..0565418e60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,6 +1,7 @@
 from typing import Any, Iterable, List, Optional
 
 from synapse.config import (
+    account_validity,
     api,
     appservice,
     auth,
@@ -59,6 +60,7 @@ class RootConfig:
     captcha: captcha.CaptchaConfig
     voip: voip.VoipConfig
     registration: registration.RegistrationConfig
+    account_validity: account_validity.AccountValidityConfig
     metrics: metrics.MetricsConfig
     api: api.ApiConfig
     appservice: appservice.AppServiceConfig
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
new file mode 100644
index 0000000000..6d107944a3
--- /dev/null
+++ b/synapse/config/account_validity.py
@@ -0,0 +1,149 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.config._base import Config, ConfigError
+
+
+class AccountValidityConfig(Config):
+    section = "account_validity"
+
+    def read_config(self, config, **kwargs):
+        account_validity_config = config.get("account_validity") or {}
+        self.account_validity_enabled = account_validity_config.get("enabled", False)
+        self.account_validity_renew_by_email_enabled = (
+            "renew_at" in account_validity_config
+        )
+
+        if self.account_validity_enabled:
+            if "period" in account_validity_config:
+                self.account_validity_period = self.parse_duration(
+                    account_validity_config["period"]
+                )
+            else:
+                raise ConfigError("'period' is required when using account validity")
+
+            if "renew_at" in account_validity_config:
+                self.account_validity_renew_at = self.parse_duration(
+                    account_validity_config["renew_at"]
+                )
+
+            if "renew_email_subject" in account_validity_config:
+                self.account_validity_renew_email_subject = account_validity_config[
+                    "renew_email_subject"
+                ]
+            else:
+                self.account_validity_renew_email_subject = "Renew your %(app)s account"
+
+            self.account_validity_startup_job_max_delta = (
+                self.account_validity_period * 10.0 / 100.0
+            )
+
+        if self.account_validity_renew_by_email_enabled:
+            if not self.public_baseurl:
+                raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+
+        # Load account validity templates.
+        # We do this here instead of in AccountValidityConfig as read_templates
+        # relies on state that hasn't been initialised in AccountValidityConfig
+        account_renewed_template_filename = account_validity_config.get(
+            "account_renewed_html_path", "account_renewed.html"
+        )
+        account_previously_renewed_template_filename = account_validity_config.get(
+            "account_previously_renewed_html_path", "account_previously_renewed.html"
+        )
+        invalid_token_template_filename = account_validity_config.get(
+            "invalid_token_html_path", "invalid_token.html"
+        )
+        custom_template_directory = account_validity_config.get("template_dir")
+
+        (
+            self.account_validity_account_renewed_template,
+            self.account_validity_account_previously_renewed_template,
+            self.account_validity_invalid_token_template,
+        ) = self.read_templates(
+            [
+                account_renewed_template_filename,
+                account_previously_renewed_template_filename,
+                invalid_token_template_filename,
+            ],
+            custom_template_directory=custom_template_directory,
+        )
+
+    def generate_config_section(self, **kwargs):
+        return """\
+        ## Account Validity ##
+        #
+        # Optional account validity configuration. This allows for accounts to be denied
+        # any request after a given period.
+        #
+        # Once this feature is enabled, Synapse will look for registered users without an
+        # expiration date at startup and will add one to every account it found using the
+        # current settings at that time.
+        # This means that, if a validity period is set, and Synapse is restarted (it will
+        # then derive an expiration date from the current validity period), and some time
+        # after that the validity period changes and Synapse is restarted, the users'
+        # expiration dates won't be updated unless their account is manually renewed. This
+        # date will be randomly selected within a range [now + period - d ; now + period],
+        # where d is equal to 10% of the validity period.
+        #
+        account_validity:
+          # The account validity feature is disabled by default. Uncomment the
+          # following line to enable it.
+          #
+          #enabled: true
+
+          # The period after which an account is valid after its registration. When
+          # renewing the account, its validity period will be extended by this amount
+          # of time. This parameter is required when using the account validity
+          # feature.
+          #
+          #period: 6w
+
+          # The amount of time before an account's expiry date at which Synapse will
+          # send an email to the account's email address with a renewal link. By
+          # default, no such emails are sent.
+          #
+          # If you enable this setting, you will also need to fill out the 'email' and
+          # 'public_baseurl' configuration sections.
+          #
+          #renew_at: 1w
+
+          # The subject of the email sent out with the renewal link. '%(app)s' can be
+          # used as a placeholder for the 'app_name' parameter from the 'email'
+          # section.
+          #
+          # Note that the placeholder must be written '%(app)s', including the
+          # trailing 's'.
+          #
+          # If this is not set, a default value is used.
+          #
+          #renew_email_subject: "Renew your %(app)s account"
+
+          # Directory in which Synapse will try to find templates for the HTML files to
+          # serve to the user when trying to renew an account. If not set, default
+          # templates from within the Synapse package will be used.
+          #
+          #template_dir: "res/templates"
+
+          # File within 'template_dir' giving the HTML to be displayed to the user after
+          # they successfully renewed their account. If not set, default text is used.
+          #
+          #account_renewed_html_path: "account_renewed.html"
+
+          # File within 'template_dir' giving the HTML to be displayed when the user
+          # tries to renew an account with an invalid renewal token. If not set,
+          # default text is used.
+          #
+          #invalid_token_html_path: "invalid_token.html"
+        """
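The config comments above describe a randomised expiration date in the range [now + period - d, now + period], where d is 10% of the validity period (matching account_validity_startup_job_max_delta). A small self-contained sketch of that computation, with illustrative names:

    import random

    def pick_expiration_ts(now_ms: int, period_ms: int) -> int:
        # d is 10% of the validity period; pick uniformly within
        # [now + period - d, now + period].
        max_delta_ms = period_ms * 10 // 100
        return now_ms + period_ms - random.randrange(0, max_delta_ms + 1)
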
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 74cd53a8ed..0638ed8d2e 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -1,4 +1,5 @@
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,27 +17,31 @@ from synapse.api.constants import EventTypes
 
 from ._base import Config
 
+# The default types of room state to send to users to are invited to or knock on a room.
+DEFAULT_ROOM_STATE_TYPES = [
+    EventTypes.JoinRules,
+    EventTypes.CanonicalAlias,
+    EventTypes.RoomAvatar,
+    EventTypes.RoomEncryption,
+    EventTypes.Name,
+]
+
 
 class ApiConfig(Config):
     section = "api"
 
     def read_config(self, config, **kwargs):
         self.room_invite_state_types = config.get(
-            "room_invite_state_types",
-            [
-                EventTypes.JoinRules,
-                EventTypes.CanonicalAlias,
-                EventTypes.RoomAvatar,
-                EventTypes.RoomEncryption,
-                EventTypes.Name,
-            ],
+            "room_invite_state_types", DEFAULT_ROOM_STATE_TYPES
         )
 
     def generate_config_section(cls, **kwargs):
         return """\
         ## API Configuration ##
 
-        # A list of event types that will be included in the room_invite_state
+        # A list of event types from a room that will be given to users when they
+        # are invited to a room. This allows clients to display information about the
+        # room that they've been invited to, without actually being in the room yet.
         #
         #room_invite_state_types:
         #  - "{JoinRules}"
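DEFAULT_ROOM_STATE_TYPES feeds the "stripped state" that is sent alongside invites and knocks (see the appservice/api.py change earlier). A hedged sketch of the filtering step, assuming state events represented as plain dicts (illustrative, not Synapse's actual serialisation code):

    def strip_state(state_events, allowed_types):
        # Keep only the configured event types, and only the fields a client
        # needs to preview a room it has not joined yet.
        return [
            {
                "type": e["type"],
                "state_key": e["state_key"],
                "content": e["content"],
                "sender": e["sender"],
            }
            for e in state_events
            if e["type"] in allowed_types
        ]
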
diff --git a/synapse/config/auth.py b/synapse/config/auth.py
index 2b3e2ce87b..7fa64b821a 100644
--- a/synapse/config/auth.py
+++ b/synapse/config/auth.py
@@ -18,8 +18,7 @@ from ._base import Config
 
 
 class AuthConfig(Config):
-    """Password and login configuration
-    """
+    """Password and login configuration"""
 
     section = "auth"
 
@@ -98,7 +97,7 @@ class AuthConfig(Config):
         # session to be active.
         #
         # This defaults to 0, meaning the user is queried for their credentials
-        # before every action, but this can be overridden to alow a single
+        # before every action, but this can be overridden to allow a single
        # validation to be re-used.  This weakens the protections afforded by
         # the user-interactive authentication process, by allowing for multiple
         # (and potentially different) operations to use the same validation session.
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index aaa7eba110..dbf5085965 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -13,7 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, List
+
+from synapse.config.sso import SsoAttributeRequirement
+
 from ._base import Config, ConfigError
+from ._util import validate_config
 
 
 class CasConfig(Config):
@@ -40,12 +45,16 @@ class CasConfig(Config):
             # TODO Update this to a _synapse URL.
             self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket"
             self.cas_displayname_attribute = cas_config.get("displayname_attribute")
-            self.cas_required_attributes = cas_config.get("required_attributes") or {}
+            required_attributes = cas_config.get("required_attributes") or {}
+            self.cas_required_attributes = _parsed_required_attributes_def(
+                required_attributes
+            )
+
         else:
             self.cas_server_url = None
             self.cas_service_url = None
             self.cas_displayname_attribute = None
-            self.cas_required_attributes = {}
+            self.cas_required_attributes = []
 
     def generate_config_section(self, config_dir_path, server_name, **kwargs):
         return """\
@@ -77,3 +86,22 @@ class CasConfig(Config):
         #   userGroup: "staff"
         #   department: None
         """
+
+
+# CAS uses a legacy required attributes mapping, not the one provided by
+# SsoAttributeRequirement.
+REQUIRED_ATTRIBUTES_SCHEMA = {
+    "type": "object",
+    "additionalProperties": {"anyOf": [{"type": "string"}, {"type": "null"}]},
+}
+
+
+def _parsed_required_attributes_def(
+    required_attributes: Any,
+) -> List[SsoAttributeRequirement]:
+    validate_config(
+        REQUIRED_ATTRIBUTES_SCHEMA,
+        required_attributes,
+        config_path=("cas_config", "required_attributes"),
+    )
+    return [SsoAttributeRequirement(k, v) for k, v in required_attributes.items()]
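Once parsed into SsoAttributeRequirement objects, the requirements can be checked against a CAS response. A sketch under the assumption that each requirement carries an attribute name and an optional expected value, where None means "the attribute merely has to be present" (illustrative only):

    def attributes_match(requirements, cas_attributes) -> bool:
        for req in requirements:
            if req.attribute not in cas_attributes:
                return False
            if req.value is not None and cas_attributes[req.attribute] != req.value:
                return False
        return True
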
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 8a18a9ca2a..e7889b9c20 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -207,8 +207,7 @@ class DatabaseConfig(Config):
         )
 
     def get_single_database(self) -> DatabaseConnectionConfig:
-        """Returns the database if there is only one, useful for e.g. tests
-        """
+        """Returns the database if there is only one, useful for e.g. tests"""
         if not self.databases:
             raise Exception("More than one database exists")
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index d4328c46b9..5431691831 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -289,7 +289,8 @@ class EmailConfig(Config):
                 self.email_notif_template_html,
                 self.email_notif_template_text,
             ) = self.read_templates(
-                [notif_template_html, notif_template_text], template_dir,
+                [notif_template_html, notif_template_text],
+                template_dir,
             )
 
             self.email_notif_for_new_users = email_config.get(
@@ -299,7 +300,7 @@ class EmailConfig(Config):
                 "client_base_url", email_config.get("riot_base_url", None)
             )
 
-        if self.account_validity.renew_by_email_enabled:
+        if self.account_validity_renew_by_email_enabled:
             expiry_template_html = email_config.get(
                 "expiry_template_html", "notice_expiry.html"
             )
@@ -311,7 +312,8 @@ class EmailConfig(Config):
                 self.account_validity_template_html,
                 self.account_validity_template_text,
             ) = self.read_templates(
-                [expiry_template_html, expiry_template_text], template_dir,
+                [expiry_template_html, expiry_template_text],
+                template_dir,
             )
 
         subjects_config = email_config.get("subjects", {})
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index b1c1c51e4d..ba9d37553b 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.config._base import Config
 from synapse.types import JsonDict
 
@@ -25,5 +26,11 @@ class ExperimentalConfig(Config):
     def read_config(self, config: JsonDict, **kwargs):
         experimental = config.get("experimental_features") or {}
 
+        # MSC2403 (room knocking)
+        self.msc2403_enabled = experimental.get("msc2403_enabled", False)  # type: bool
+        if self.msc2403_enabled:
+            # Enable the MSC2403 unstable room version
+            KNOWN_ROOM_VERSIONS.update({RoomVersions.V7.identifier: RoomVersions.V7})
+
         # MSC2858 (multiple SSO identity providers)
         self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 64a2429f77..58961679ff 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -13,8 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from ._base import RootConfig
+from .account_validity import AccountValidityConfig
 from .api import ApiConfig
 from .appservice import AppServiceConfig
 from .auth import AuthConfig
@@ -69,6 +69,7 @@ class HomeServerConfig(RootConfig):
         CaptchaConfig,
         VoipConfig,
         RegistrationConfig,
+        AccountValidityConfig,
         MetricsConfig,
         ApiConfig,
         AppServiceConfig,
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 4df3f93c1c..e56cf846f5 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -162,7 +162,10 @@ class LoggingConfig(Config):
         )
 
         logging_group.add_argument(
-            "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS,
+            "-f",
+            "--log-file",
+            dest="log_file",
+            help=argparse.SUPPRESS,
         )
 
     def generate_files(self, config, config_dir_path):
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 4d0f24a9d5..a27594befc 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -201,9 +201,9 @@ class OIDCConfig(Config):
         #    user_mapping_provider:
         #      config:
         #        subject_claim: "id"
-        #        localpart_template: "{{ user.login }}"
-        #        display_name_template: "{{ user.name }}"
-        #        email_template: "{{ user.email }}"
+        #        localpart_template: "{{{{ user.login }}}}"
+        #        display_name_template: "{{{{ user.name }}}}"
+        #        email_template: "{{{{ user.email }}}}"
 
         # For use with Keycloak
         #
@@ -230,8 +230,8 @@ class OIDCConfig(Config):
         #    user_mapping_provider:
         #      config:
         #        subject_claim: "id"
-        #        localpart_template: "{{ user.login }}"
-        #        display_name_template: "{{ user.name }}"
+        #        localpart_template: "{{{{ user.login }}}}"
+        #        display_name_template: "{{{{ user.name }}}}"
         """.format(
             mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
         )
@@ -355,9 +355,10 @@ def _parse_oidc_config_dict(
     ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
     ump_config.setdefault("config", {})
 
-    (user_mapping_provider_class, user_mapping_provider_config,) = load_module(
-        ump_config, config_path + ("user_mapping_provider",)
-    )
+    (
+        user_mapping_provider_class,
+        user_mapping_provider_config,
+    ) = load_module(ump_config, config_path + ("user_mapping_provider",))
 
     # Ensure loaded user mapping module has defined all necessary methods
     required_methods = [
@@ -372,7 +373,11 @@ def _parse_oidc_config_dict(
     if missing_methods:
         raise ConfigError(
             "Class %s is missing required "
-            "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
+            "methods: %s"
+            % (
+                user_mapping_provider_class,
+                ", ".join(missing_methods),
+            ),
             config_path + ("user_mapping_provider", "module"),
         )
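The brace quadrupling above exists because the sample config string is later passed through str.format(), which treats "{{" as an escaped literal "{". To survive formatting and still show the literal Jinja2 placeholder "{{ user.login }}", the source needs four braces. A two-line demonstration:

    template = 'localpart_template: "{{{{ user.login }}}}"'
    print(template.format())  # prints: localpart_template: "{{ user.login }}"
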
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index def33a60ad..070eb1b761 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -76,6 +76,9 @@ class RatelimitConfig(Config):
         )
 
         self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+        self.rc_third_party_invite = RateLimitConfig(
+            config.get("rc_third_party_invite", {})
+        )
 
         rc_login_config = config.get("rc_login", {})
         self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -138,6 +141,8 @@ class RatelimitConfig(Config):
         #   - one for login that ratelimits login requests based on the account the
         #     client is attempting to log into, based on the amount of failed login
         #     attempts for this account.
+        #   - one that ratelimits third-party invites requests based on the account
+        #     that's making the requests.
         #   - one for ratelimiting redactions by room admins. If this is not explicitly
         #     set then it uses the same ratelimiting as per rc_message. This is useful
         #     to allow room admins to deal with abuse quickly.
@@ -170,6 +175,10 @@ class RatelimitConfig(Config):
         #  per_second: 0.17
         #  burst_count: 3
         #
+        #rc_third_party_invite:
+        #  per_second: 0.2
+        #  burst_count: 10
+        #
         #rc_admin_redaction:
         #  per_second: 1
         #  burst_count: 50
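The per_second/burst_count pair used by rc_third_party_invite (and the other rc_* settings) follows token-bucket semantics: a bucket holding at most burst_count tokens, refilled at per_second tokens per second. A minimal self-contained sketch of that behaviour, not Synapse's actual Ratelimiter class:

    import time

    class TokenBucket:
        def __init__(self, per_second: float, burst_count: int):
            self.per_second = per_second
            self.burst_count = burst_count
            self.tokens = float(burst_count)
            self.last = time.monotonic()

        def allow(self) -> bool:
            # Refill proportionally to elapsed time, capped at burst_count,
            # then spend one token if available.
            now = time.monotonic()
            self.tokens = min(
                self.burst_count, self.tokens + (now - self.last) * self.per_second
            )
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False

With per_second=0.2 and burst_count=10, a client can make 10 requests immediately and thereafter roughly one every five seconds.
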
index eb650af7fb..c2ffbc7c13 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py
@@ -13,74 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os - -import pkg_resources - from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID from synapse.util.stringutils import random_string_with_symbols, strtobool -class AccountValidityConfig(Config): - section = "accountvalidity" - - def __init__(self, config, synapse_config): - if config is None: - return - super().__init__() - self.enabled = config.get("enabled", False) - self.renew_by_email_enabled = "renew_at" in config - - if self.enabled: - if "period" in config: - self.period = self.parse_duration(config["period"]) - else: - raise ConfigError("'period' is required when using account validity") - - if "renew_at" in config: - self.renew_at = self.parse_duration(config["renew_at"]) - - if "renew_email_subject" in config: - self.renew_email_subject = config["renew_email_subject"] - else: - self.renew_email_subject = "Renew your %(app)s account" - - self.startup_job_max_delta = self.period * 10.0 / 100.0 - - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - - template_dir = config.get("template_dir") - - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - if "account_renewed_html_path" in config: - file_path = os.path.join(template_dir, config["account_renewed_html_path"]) - - self.account_renewed_html_content = self.read_file( - file_path, "account_validity.account_renewed_html_path" - ) - else: - self.account_renewed_html_content = ( - "<html><body>Your account has been successfully renewed.</body><html>" - ) - - if "invalid_token_html_path" in config: - file_path = os.path.join(template_dir, config["invalid_token_html_path"]) - - self.invalid_token_html_content = self.read_file( - file_path, "account_validity.invalid_token_html_path" - ) - else: - self.invalid_token_html_content = ( - "<html><body>Invalid renewal token.</body><html>" - ) - - class RegistrationConfig(Config): section = "registration" @@ -93,14 +31,21 @@ class RegistrationConfig(Config): str(config["disable_registration"]) ) - self.account_validity = AccountValidityConfig( - config.get("account_validity") or {}, config - ) - self.registrations_require_3pid = config.get("registrations_require_3pid", []) self.allowed_local_3pids = config.get("allowed_local_3pids", []) + self.check_is_for_allowed_local_3pids = config.get( + "check_is_for_allowed_local_3pids", None + ) + self.allow_invited_3pids = config.get("allow_invited_3pids", False) + + self.disable_3pid_changes = config.get("disable_3pid_changes", False) + self.enable_3pid_lookup = config.get("enable_3pid_lookup", True) self.registration_shared_secret = config.get("registration_shared_secret") + self.register_mxid_from_3pid = config.get("register_mxid_from_3pid") + self.register_just_use_email_for_display_name = config.get( + "register_just_use_email_for_display_name", False + ) self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.trusted_third_party_id_servers = config.get( @@ -108,7 +53,21 @@ class RegistrationConfig(Config): ) account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") + if ( + self.account_threepid_delegate_email + and not self.account_threepid_delegate_email.startswith("http") + 
): + raise ConfigError( + "account_threepid_delegates.email must begin with http:// or https://" + ) self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") + if ( + self.account_threepid_delegate_msisdn + and not self.account_threepid_delegate_msisdn.startswith("http") + ): + raise ConfigError( + "account_threepid_delegates.msisdn must begin with http:// or https://" + ) if self.account_threepid_delegate_msisdn and not self.public_baseurl: raise ConfigError( "The configuration option `public_baseurl` is required if " @@ -177,6 +136,15 @@ class RegistrationConfig(Config): self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) self.enable_3pid_changes = config.get("enable_3pid_changes", True) + self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", []) + if not isinstance(self.replicate_user_profiles_to, list): + self.replicate_user_profiles_to = [self.replicate_user_profiles_to] + + self.shadow_server = config.get("shadow_server", None) + self.rewrite_identity_server_urls = ( + config.get("rewrite_identity_server_urls") or {} + ) + self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False ) @@ -189,6 +157,23 @@ class RegistrationConfig(Config): # The success template used during fallback auth. self.fallback_success_template = self.read_template("auth_success.html") + self.bind_new_user_emails_to_sydent = config.get( + "bind_new_user_emails_to_sydent" + ) + + if self.bind_new_user_emails_to_sydent: + if not isinstance( + self.bind_new_user_emails_to_sydent, str + ) or not self.bind_new_user_emails_to_sydent.startswith("http"): + raise ConfigError( + "Option bind_new_user_emails_to_sydent has invalid value" + ) + + # Remove trailing slashes + self.bind_new_user_emails_to_sydent = self.bind_new_user_emails_to_sydent.strip( + "/" + ) + def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: registration_shared_secret = 'registration_shared_secret: "%s"' % ( @@ -208,69 +193,6 @@ class RegistrationConfig(Config): # #enable_registration: false - # Optional account validity configuration. This allows for accounts to be denied - # any request after a given period. - # - # Once this feature is enabled, Synapse will look for registered users without an - # expiration date at startup and will add one to every account it found using the - # current settings at that time. - # This means that, if a validity period is set, and Synapse is restarted (it will - # then derive an expiration date from the current validity period), and some time - # after that the validity period changes and Synapse is restarted, the users' - # expiration dates won't be updated unless their account is manually renewed. This - # date will be randomly selected within a range [now + period - d ; now + period], - # where d is equal to 10%% of the validity period. - # - account_validity: - # The account validity feature is disabled by default. Uncomment the - # following line to enable it. - # - #enabled: true - - # The period after which an account is valid after its registration. When - # renewing the account, its validity period will be extended by this amount - # of time. This parameter is required when using the account validity - # feature. - # - #period: 6w - - # The amount of time before an account's expiry date at which Synapse will - # send an email to the account's email address with a renewal link. By - # default, no such emails are sent. 
- # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. - # - #renew_at: 1w - - # The subject of the email sent out with the renewal link. '%%(app)s' can be - # used as a placeholder for the 'app_name' parameter from the 'email' - # section. - # - # Note that the placeholder must be written '%%(app)s', including the - # trailing 's'. - # - # If this is not set, a default value is used. - # - #renew_email_subject: "Renew your %%(app)s account" - - # Directory in which Synapse will try to find templates for the HTML files to - # serve to the user when trying to renew an account. If not set, default - # templates from within the Synapse package will be used. - # - #template_dir: "res/templates" - - # File within 'template_dir' giving the HTML to be displayed to the user after - # they successfully renewed their account. If not set, default text is used. - # - #account_renewed_html_path: "account_renewed.html" - - # File within 'template_dir' giving the HTML to be displayed when the user - # tries to renew an account with an invalid renewal token. If not set, - # default text is used. - # - #invalid_token_html_path: "invalid_token.html" - # Time that a user's session remains valid for, after they log in. # # Note that this is not currently compatible with guest logins. @@ -293,9 +215,32 @@ class RegistrationConfig(Config): # #disable_msisdn_registration: true + # Derive the user's matrix ID from a type of 3PID used when registering. + # This overrides any matrix ID the user proposes when calling /register + # The 3PID type should be present in registrations_require_3pid to avoid + # users failing to register if they don't specify the right kind of 3pid. + # + #register_mxid_from_3pid: email + + # Uncomment to set the display name of new users to their email address, + # rather than using the default heuristic. + # + #register_just_use_email_for_display_name: true + # Mandate that users are only allowed to associate certain formats of # 3PIDs with accounts on this server. # + # Use an Identity Server to establish which 3PIDs are allowed to register? + # Overrides allowed_local_3pids below. + # + #check_is_for_allowed_local_3pids: matrix.org + # + # If you are using an IS you can also check whether that IS registers + # pending invites for the given 3PID (and then allow it to sign up on + # the platform): + # + #allow_invited_3pids: false + # #allowed_local_3pids: # - medium: email # pattern: '.*@matrix\\.org' @@ -304,6 +249,11 @@ class RegistrationConfig(Config): # - medium: msisdn # pattern: '\\+44' + # If true, stop users from trying to change the 3PIDs associated with + # their accounts. + # + #disable_3pid_changes: false + # Enable 3PIDs lookup requests to identity servers from this server. # #enable_3pid_lookup: true @@ -335,6 +285,30 @@ class RegistrationConfig(Config): # #default_identity_server: https://matrix.org + # If enabled, user IDs, display names and avatar URLs will be replicated + # to this server whenever they change. + # This is an experimental API currently implemented by sydent to support + # cross-homeserver user directories. + # + #replicate_user_profiles_to: example.com + + # If specified, attempt to replay registrations, profile changes & 3pid + # bindings on the given target homeserver via the AS API. The HS is authed + # via a given AS token. 
+ # + #shadow_server: + # hs_url: https://shadow.example.com + # hs: shadow.example.com + # as_token: 12u394refgbdhivsia + + # If enabled, don't let users set their own display names/avatars + # other than for the very first time (unless they are a server admin). + # Useful when provisioning users based on the contents of a 3rd party + # directory and to avoid ambiguities. + # + #disable_set_displayname: false + #disable_set_avatar_url: false + # Handle threepid (email/phone etc) registration and password resets through a set of # *trusted* identity servers. Note that this allows the configured identity server to # reset passwords for accounts! @@ -391,6 +365,8 @@ class RegistrationConfig(Config): # By default, any room aliases included in this list will be created # as a publicly joinable room when the first user registers for the # homeserver. This behaviour can be customised with the settings below. + # If the room already exists, make certain it is a publicly joinable + # room. The join rule of the room must be set to 'public'. # #auto_join_rooms: # - "#example:example.com" @@ -461,6 +437,31 @@ class RegistrationConfig(Config): # Defaults to true. # #auto_join_rooms_for_guests: false + + # Rewrite identity server URLs with a map from one URL to another. Applies to URLs + # provided by clients (which have https:// prepended) and those specified + # in `account_threepid_delegates`. URLs should not feature a trailing slash. + # + #rewrite_identity_server_urls: + # "https://somewhere.example.com": "https://somewhereelse.example.com" + + # When a user registers an account with an email address, it can be useful to + # bind that email address to their mxid on an identity server. Typically, this + # requires the user to validate their email address with the identity server. + # However if Synapse itself is handling email validation on registration, the + # user ends up needing to validate their email twice, which leads to poor UX. + # + # It is possible to force Sydent, one identity server implementation, to bind + # threepids using its internal, unauthenticated bind API: + # https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api + # + # Configure the address of a Sydent server here to have Synapse attempt + # to automatically bind users' emails following registration. The + # internal bind API must be reachable from Synapse, but should NOT be + # exposed to any third party, as it allows the creation of bindings + # without validation. + # + #bind_new_user_emails_to_sydent: https://example.com:8091 """ % locals() ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py
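The bind_new_user_emails_to_sydent handling above follows a small, reusable pattern: insist on an http(s) URL, then normalise away slashes before storing it. A minimal standalone sketch of that pattern (ConfigError is stubbed in; Synapse's real class lives in synapse.config._base):

    class ConfigError(Exception):
        """Stand-in for synapse.config._base.ConfigError."""


    def read_url_option(config: dict, name: str):
        value = config.get(name)
        if value is None:
            return None
        if not isinstance(value, str) or not value.startswith("http"):
            raise ConfigError("Option %s has invalid value" % (name,))
        # Mirror the .strip("/") above: drop surrounding slashes so later
        # path joins don't produce double slashes.
        return value.strip("/")


    print(read_url_option(
        {"bind_new_user_emails_to_sydent": "https://example.com:8091/"},
        "bind_new_user_emails_to_sydent",
    ))  # -> https://example.com:8091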
index 850ac3ebd6..45f90beabc 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py
@@ -17,9 +17,7 @@ import os from collections import namedtuple from typing import Dict, List -from netaddr import IPSet - -from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST +from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module @@ -54,7 +52,7 @@ MediaStorageProviderConfig = namedtuple( def parse_thumbnail_requirements(thumbnail_sizes): - """ Takes a list of dictionaries with "width", "height", and "method" keys + """Takes a list of dictionaries with "width", "height", and "method" keys and creates a map from image media types to the thumbnail size, thumbnailing method, and thumbnail media type to precalculate @@ -107,6 +105,12 @@ class ContentRepositoryConfig(Config): self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M")) self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M")) + self.max_avatar_size = config.get("max_avatar_size") + if self.max_avatar_size: + self.max_avatar_size = self.parse_size(self.max_avatar_size) + + self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", []) + self.media_store_path = self.ensure_directory( config.get("media_store_path", "media_store") ) @@ -187,16 +191,17 @@ class ContentRepositoryConfig(Config): "to work" ) - self.url_preview_ip_range_blacklist = IPSet( - config["url_preview_ip_range_blacklist"] - ) - # we always blacklist '0.0.0.0' and '::', which are supposed to be # unroutable addresses. - self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"]) + self.url_preview_ip_range_blacklist = generate_ip_set( + config["url_preview_ip_range_blacklist"], + ["0.0.0.0", "::"], + config_path=("url_preview_ip_range_blacklist",), + ) - self.url_preview_ip_range_whitelist = IPSet( - config.get("url_preview_ip_range_whitelist", ()) + self.url_preview_ip_range_whitelist = generate_ip_set( + config.get("url_preview_ip_range_whitelist", ()), + config_path=("url_preview_ip_range_whitelist",), ) self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ()) @@ -250,6 +255,30 @@ class ContentRepositoryConfig(Config): # #max_upload_size: 50M + # The largest allowed size for a user avatar. If not defined, no + # restriction will be imposed. + # + # Note that this only applies when an avatar is changed globally. + # Per-room avatar changes are not affected. See allow_per_room_profiles + # for disabling that functionality. + # + # Note that user avatar changes will not work if this is set without + # using Synapse's local media repo. + # + #max_avatar_size: 10M + + # Allow mimetypes for a user avatar. If not defined, no restriction will + # be imposed. + # + # Note that this only applies when an avatar is changed globally. + # Per-room avatar changes are not affected. See allow_per_room_profiles + # for disabling that functionality. + # + # Note that user avatar changes will not work if this is set without + # using Synapse's local media repo. + # + #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"] + # Maximum number of pixels that will be thumbnailed # #max_image_pixels: 32M diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
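The max_avatar_size handling above is the usual optional-size idiom: an absent option means no limit, otherwise a human-readable size string is converted to bytes. A self-contained sketch, where parse_size is a simplified stand-in for Synapse's Config.parse_size:

    def parse_size(value):
        """Simplified stand-in for Config.parse_size: "10M" -> bytes."""
        if isinstance(value, int):
            return value
        sizes = {"K": 1024, "M": 1024 * 1024}
        suffix = value[-1]
        if suffix in sizes:
            return int(value[:-1]) * sizes[suffix]
        return int(value)


    config = {"max_avatar_size": "10M"}
    max_avatar_size = config.get("max_avatar_size")
    if max_avatar_size:
        # Only parse when the option is present; None keeps avatars unlimited.
        max_avatar_size = parse_size(max_avatar_size)
    print(max_avatar_size)  # 10485760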
index 9a3e1c3e7d..2dd719c388 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py
@@ -123,7 +123,7 @@ class RoomDirectoryConfig(Config): alias (str) Returns: - boolean: True if user is allowed to crate the alias + boolean: True if user is allowed to create the alias """ for rule in self._alias_creation_rules: if rule.matches(user_id, room_id, [alias]): diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 7226abd829..4b494f217f 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py
@@ -17,8 +17,7 @@ import logging from typing import Any, List -import attr - +from synapse.config.sso import SsoAttributeRequirement from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module @@ -398,32 +397,18 @@ class SAML2Config(Config): } -@attr.s(frozen=True) -class SamlAttributeRequirement: - """Object describing a single requirement for SAML attributes.""" - - attribute = attr.ib(type=str) - value = attr.ib(type=str) - - JSON_SCHEMA = { - "type": "object", - "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, - "required": ["attribute", "value"], - } - - ATTRIBUTE_REQUIREMENTS_SCHEMA = { "type": "array", - "items": SamlAttributeRequirement.JSON_SCHEMA, + "items": SsoAttributeRequirement.JSON_SCHEMA, } def _parse_attribute_requirements_def( attribute_requirements: Any, -) -> List[SamlAttributeRequirement]: +) -> List[SsoAttributeRequirement]: validate_config( ATTRIBUTE_REQUIREMENTS_SCHEMA, attribute_requirements, - config_path=["saml2_config", "attribute_requirements"], + config_path=("saml2_config", "attribute_requirements"), ) - return [SamlAttributeRequirement(**x) for x in attribute_requirements] + return [SsoAttributeRequirement(**x) for x in attribute_requirements] diff --git a/synapse/config/server.py b/synapse/config/server.py
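With this change, SAML attribute requirements parse into the same frozen attrs class that the sso config module now owns. A sketch of the parsing step in isolation (the class body is copied in so the snippet runs alone; the jsonschema validation step is elided):

    import attr


    @attr.s(frozen=True)
    class SsoAttributeRequirement:
        # Copied from synapse/config/sso.py below so this snippet runs alone.
        attribute = attr.ib(type=str)
        value = attr.ib(type=str)


    attribute_requirements = [{"attribute": "userGroup", "value": "staff"}]
    parsed = [SsoAttributeRequirement(**x) for x in attribute_requirements]
    print(parsed[0])  # SsoAttributeRequirement(attribute='userGroup', value='staff')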
index 5d72cf2d82..1b82c81db1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging import os.path import re @@ -23,7 +24,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml -from netaddr import IPSet +from netaddr import AddrFormatError, IPNetwork, IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.util.stringutils import parse_and_validate_server_name @@ -40,6 +41,71 @@ logger = logging.Logger(__name__) # in the list. DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] + +def _6to4(network: IPNetwork) -> IPNetwork: + """Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056.""" + + # 6to4 networks consist of: + # * 2002 as the first 16 bits + # * The first IPv4 address in the network hex-encoded as the next 32 bits + # * The new prefix length needs to include the bits from the 2002 prefix. + hex_network = hex(network.first)[2:] + hex_network = ("0" * (8 - len(hex_network))) + hex_network + return IPNetwork( + "2002:%s:%s::/%d" + % ( + hex_network[:4], + hex_network[4:], + 16 + network.prefixlen, + ) + ) + + +def generate_ip_set( + ip_addresses: Optional[Iterable[str]], + extra_addresses: Optional[Iterable[str]] = None, + config_path: Optional[Iterable[str]] = None, +) -> IPSet: + """ + Generate an IPSet from a list of IP addresses or CIDRs. + + Additionally, for each IPv4 network in the list of IP addresses, also + includes the corresponding IPv6 networks. + + This includes: + + * IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1) + * IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2) + * 6to4 Address (see RFC 3056, section 2) + + Args: + ip_addresses: An iterable of IP addresses or CIDRs. + extra_addresses: An iterable of IP addresses or CIDRs. + config_path: The path in the configuration for error messages. + + Returns: + A new IP set. + """ + result = IPSet() + for ip in itertools.chain(ip_addresses or (), extra_addresses or ()): + try: + network = IPNetwork(ip) + except AddrFormatError as e: + raise ConfigError( + "Invalid IP range provided: %s." % (ip,), config_path + ) from e + result.add(network) + + # It is possible that these already exist in the set, but that's OK. + if ":" not in str(network): + result.add(IPNetwork(network).ipv6(ipv4_compatible=True)) + result.add(IPNetwork(network).ipv6(ipv4_compatible=False)) + result.add(_6to4(network)) + + return result + + +# IP ranges that are considered private / unroutable / don't make sense. DEFAULT_IP_RANGE_BLACKLIST = [ # Localhost "127.0.0.0/8", @@ -53,6 +119,8 @@ DEFAULT_IP_RANGE_BLACKLIST = [ "192.0.0.0/24", # Link-local networks. "169.254.0.0/16", + # Formerly used for 6to4 relay. + "192.88.99.0/24", # Testing networks. "198.18.0.0/15", "192.0.2.0/24", @@ -66,6 +134,12 @@ DEFAULT_IP_RANGE_BLACKLIST = [ "fe80::/10", # Unique local addresses. "fc00::/7", + # Testing networks. + "2001:db8::/32", + # Multicast. 
+ "ff00::/8", + # Site-local addresses + "fec0::/10", ] DEFAULT_ROOM_VERSION = "6" @@ -185,7 +259,8 @@ class ServerConfig(Config): # Whether to require sharing a room with a user to retrieve their # profile data self.limit_profile_requests_to_users_who_share_rooms = config.get( - "limit_profile_requests_to_users_who_share_rooms", False, + "limit_profile_requests_to_users_who_share_rooms", + False, ) if "restrict_public_rooms_to_local_users" in config and ( @@ -290,17 +365,15 @@ class ServerConfig(Config): ) # Attempt to create an IPSet from the given ranges - try: - self.ip_range_blacklist = IPSet(ip_range_blacklist) - except Exception as e: - raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e + # Always blacklist 0.0.0.0, :: - self.ip_range_blacklist.update(["0.0.0.0", "::"]) + self.ip_range_blacklist = generate_ip_set( + ip_range_blacklist, ["0.0.0.0", "::"], config_path=("ip_range_blacklist",) + ) - try: - self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ())) - except Exception as e: - raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e + self.ip_range_whitelist = generate_ip_set( + config.get("ip_range_whitelist", ()), config_path=("ip_range_whitelist",) + ) # The federation_ip_range_blacklist is used for backwards-compatibility # and only applies to federation and identity servers. If it is not given, @@ -308,14 +381,12 @@ class ServerConfig(Config): federation_ip_range_blacklist = config.get( "federation_ip_range_blacklist", ip_range_blacklist ) - try: - self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist." - ) from e # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + self.federation_ip_range_blacklist = generate_ip_set( + federation_ip_range_blacklist, + ["0.0.0.0", "::"], + config_path=("federation_ip_range_blacklist",), + ) if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": @@ -337,6 +408,12 @@ class ServerConfig(Config): # events with profile information that differ from the target's global profile. self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) + # Whether to show the users on this homeserver in the user directory. Defaults to + # True. + self.show_users_in_user_directory = config.get( + "show_users_in_user_directory", True + ) + retention_config = config.get("retention") if retention_config is None: retention_config = {} @@ -549,7 +626,9 @@ class ServerConfig(Config): if manhole: self.listeners.append( ListenerConfig( - port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + port=manhole, + bind_addresses=["127.0.0.1"], + type="manhole", ) ) @@ -585,7 +664,8 @@ class ServerConfig(Config): # and letting the client know which email address is bound to an account and # which one isn't. self.request_token_inhibit_3pid_errors = config.get( - "request_token_inhibit_3pid_errors", False, + "request_token_inhibit_3pid_errors", + False, ) # List of users trialing the new experimental default push rules. This setting is @@ -1042,6 +1122,74 @@ class ServerConfig(Config): # #allow_per_room_profiles: false + # Whether to show the users on this homeserver in the user directory. Defaults to + # 'true'. + # + #show_users_in_user_directory: false + + # Message retention policy at the server level. 
+ # + # Room admins and mods can define a retention period for their rooms using the + # 'm.room.retention' state event, and server admins can cap this period by setting + # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options. + # + # If this feature is enabled, Synapse will regularly look for and purge events + # which are older than the room's maximum retention period. Synapse will also + # filter events received over federation so that events that should have been + # purged are ignored and not stored again. + # + retention: + # The message retention policies feature is disabled by default. Uncomment the + # following line to enable it. + # + #enabled: true + + # Default retention policy. If set, Synapse will apply it to rooms that lack the + # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't + # matter much because Synapse doesn't take it into account yet. + # + #default_policy: + # min_lifetime: 1d + # max_lifetime: 1y + + # Retention policy limits. If set, a user won't be able to send a + # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime' + # that's not within this range. This is especially useful in closed federations, + # in which server admins can make sure every federating server applies the same + # rules. + # + #allowed_lifetime_min: 1d + #allowed_lifetime_max: 1y + + # Server admins can define the settings of the background jobs purging the + # events whose lifetime has expired under the 'purge_jobs' section. + # + # If no configuration is provided, a single job will be set up to delete expired + # events in every room daily. + # + # Each job's configuration defines which range of message lifetimes the job + # takes care of. For example, if 'shortest_max_lifetime' is '2d' and + # 'longest_max_lifetime' is '3d', the job will handle purging expired events in + # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and + # lower than or equal to 3 days. Both the minimum and the maximum value of a + # range are optional, e.g. a job with no 'shortest_max_lifetime' and a + # 'longest_max_lifetime' of '3d' will handle every room with a retention policy + # whose 'max_lifetime' is lower than or equal to three days. + # + # The rationale for this per-job configuration is that some rooms might have a + # retention policy with a low 'max_lifetime', where history needs to be purged + # of outdated messages on a very frequent basis (e.g. every 5min), but where you + # may not want that purge to be performed by a job that's iterating over every + # room it knows, which would be quite heavy on the server. + # + #purge_jobs: + # - shortest_max_lifetime: 1d + # longest_max_lifetime: 3d + # interval: 5m + # - shortest_max_lifetime: 3d + # longest_max_lifetime: 1y + # interval: 24h + # How long to keep redacted events in unredacted form in the database. After # this period redacted events get replaced with their redacted form in the DB. # diff --git a/synapse/config/sso.py b/synapse/config/sso.py
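The _6to4 helper added above is easiest to follow with a worked example: using netaddr directly, the IPv4 range 192.168.0.0/16 (first address 0xC0A80000) maps to the 6to4 network 2002:c0a8::/32. zfill stands in for the helper's manual zero-padding:

    from netaddr import IPNetwork

    network = IPNetwork("192.168.0.0/16")
    # First address as eight hex digits: 0xC0A80000 -> "c0a80000"
    hex_network = hex(network.first)[2:].zfill(8)
    print(
        "2002:%s:%s::/%d"
        % (hex_network[:4], hex_network[4:], 16 + network.prefixlen)
    )  # -> 2002:c0a8:0000::/32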
index 19bdfd462b..243cc681e8 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py
@@ -12,14 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Dict, Optional + +import attr from ._base import Config +@attr.s(frozen=True) +class SsoAttributeRequirement: + """Object describing a single requirement for SSO attributes.""" + + attribute = attr.ib(type=str) + # If a value is not given, then the attribute must simply exist. + value = attr.ib(type=Optional[str]) + + JSON_SCHEMA = { + "type": "object", + "properties": {"attribute": {"type": "string"}, "value": {"type": "string"}}, + "required": ["attribute", "value"], + } + + class SSOConfig(Config): - """SSO Configuration - """ + """SSO Configuration""" section = "sso" diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
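The comment above notes that a requirement with no value only demands that the attribute exist. A minimal sketch of how such a requirement could be evaluated against a map of SSO attributes (illustrative only, not Synapse's actual checker):

    from typing import Dict, List, Optional


    def requirement_satisfied(
        attributes: Dict[str, List[str]], attribute: str, value: Optional[str]
    ) -> bool:
        """If value is None, the attribute merely has to be present."""
        if attribute not in attributes:
            return False
        if value is None:
            return True
        return value in attributes[attribute]


    print(requirement_satisfied({"group": ["staff"]}, "group", "staff"))  # True
    print(requirement_satisfied({"group": ["staff"]}, "group", None))     # True
    print(requirement_satisfied({"group": ["staff"]}, "team", None))      # False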
index c8d19c5d6b..306e0cc8a4 100644 --- a/synapse/config/user_directory.py +++ b/synapse/config/user_directory.py
@@ -26,6 +26,8 @@ class UserDirectoryConfig(Config): def read_config(self, config, **kwargs): self.user_directory_search_enabled = True self.user_directory_search_all_users = False + self.user_directory_defer_to_id_server = None + self.user_directory_search_prefer_local_users = False user_directory_config = config.get("user_directory", None) if user_directory_config: self.user_directory_search_enabled = user_directory_config.get( @@ -34,6 +36,12 @@ class UserDirectoryConfig(Config): self.user_directory_search_all_users = user_directory_config.get( "search_all_users", False ) + self.user_directory_defer_to_id_server = user_directory_config.get( + "defer_to_id_server", None + ) + self.user_directory_search_prefer_local_users = user_directory_config.get( + "prefer_local_users", False + ) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """ @@ -49,7 +57,17 @@ class UserDirectoryConfig(Config): # rebuild the user_directory search indexes, see # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md # + # 'prefer_local_users' defines whether to prioritise local users in + # search query results. If True, local users are more likely to appear above + # remote users when searching the user directory. Defaults to false. + # #user_directory: # enabled: true # search_all_users: false + # prefer_local_users: false + # + # # If this is set, user search will be delegated to this ID server instead + # # of synapse performing the search itself. + # # This is an experimental API. + # defer_to_id_server: https://id.example.com """ diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index f10e33f7b8..7a0ca16da8 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py
@@ -33,8 +33,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]: @attr.s class InstanceLocationConfig: - """The host and port to talk to an instance via HTTP replication. - """ + """The host and port to talk to an instance via HTTP replication.""" host = attr.ib(type=str) port = attr.ib(type=int) @@ -54,13 +53,19 @@ class WriterLocations: ) typing = attr.ib(default="master", type=str) to_device = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter, + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) account_data = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter, + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) receipts = attr.ib( - default=["master"], type=List[str], converter=_instance_to_list_converter, + default=["master"], + type=List[str], + converter=_instance_to_list_converter, ) @@ -107,7 +112,9 @@ class WorkerConfig(Config): if manhole: self.worker_listeners.append( ListenerConfig( - port=manhole, bind_addresses=["127.0.0.1"], type="manhole", + port=manhole, + bind_addresses=["127.0.0.1"], + type="manhole", ) ) diff --git a/synapse/event_auth.py b/synapse/event_auth.py
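The WriterLocations fields above accept either a single instance name or a list, normalised by the _instance_to_list_converter named in the hunk header. Its body is not shown in this diff; a plausible reconstruction consistent with the annotated signature:

    from typing import List, Union


    def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
        # A bare string is wrapped in a list; lists pass through unchanged.
        if isinstance(obj, str):
            return [obj]
        return obj


    print(_instance_to_list_converter("master"))                # ['master']
    print(_instance_to_list_converter(["worker1", "worker2"]))  # ['worker1', 'worker2']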
index 56f8dc9caf..4e20851d7f 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py
@@ -42,7 +42,7 @@ def check( do_sig_check: bool = True, do_size_check: bool = True, ) -> None: - """ Checks if this event is correctly authed. + """Checks if this event is correctly authed. Args: room_version_obj: the version of the room @@ -161,6 +161,7 @@ def check( if logger.isEnabledFor(logging.DEBUG): logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()]) + # 5. If type is m.room.membership if event.type == EventTypes.Member: _is_membership_change_allowed(event, auth_events) logger.debug("Allowing! %s", event) @@ -247,6 +248,7 @@ def _is_membership_change_allowed( caller_in_room = caller and caller.membership == Membership.JOIN caller_invited = caller and caller.membership == Membership.INVITE + caller_knocked = caller and caller.membership == Membership.KNOCK # get info about the target key = (EventTypes.Member, target_user_id) @@ -289,9 +291,12 @@ def _is_membership_change_allowed( raise AuthError(403, "%s is banned from the room" % (target_user_id,)) return - if Membership.JOIN != membership: + # Require the user to be in the room for membership changes other than join/knock. + if Membership.JOIN != membership and Membership.KNOCK != membership: + # If the user has been invited or has knocked, they are allowed to change their + # membership event to leave if ( - caller_invited + (caller_invited or caller_knocked) and Membership.LEAVE == membership and target_user_id == event.user_id ): @@ -324,7 +329,7 @@ def _is_membership_change_allowed( raise AuthError(403, "You are banned from this room") elif join_rule == JoinRules.PUBLIC: pass - elif join_rule == JoinRules.INVITE: + elif join_rule in (JoinRules.INVITE, JoinRules.KNOCK): if not caller_in_room and not caller_invited: raise AuthError(403, "You are not invited to this room.") else: @@ -343,6 +348,17 @@ def _is_membership_change_allowed( elif Membership.BAN == membership: if user_level < ban_level or user_level <= target_level: raise AuthError(403, "You don't have permission to ban") + elif Membership.KNOCK == membership: + if join_rule != JoinRules.KNOCK: + raise AuthError(403, "You don't have permission to knock") + elif target_user_id != event.user_id: + raise AuthError(403, "You cannot knock for other users") + elif target_in_room: + raise AuthError(403, "You cannot knock on a room you are already in") + elif caller_invited: + raise AuthError(403, "You are already invited to this room") + elif target_banned: + raise AuthError(403, "You are banned from this room") else: raise AuthError(500, "Unknown membership %s" % membership) @@ -423,7 +439,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool: def check_redaction( - room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], + room_version_obj: RoomVersion, + event: EventBase, + auth_events: StateMap[EventBase], ) -> bool: """Check whether the event sender is allowed to redact the target event. 
@@ -459,7 +477,9 @@ def check_redaction( def _check_power_levels( - room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase], + room_version_obj: RoomVersion, + event: EventBase, + auth_events: StateMap[EventBase], ) -> None: user_list = event.content.get("users", {}) # Validate users @@ -699,7 +719,7 @@ def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]: if event.type == EventTypes.Member: membership = event.content["membership"] - if membership in [Membership.JOIN, Membership.INVITE]: + if membership in [Membership.JOIN, Membership.INVITE, Membership.KNOCK]: auth_types.add((EventTypes.JoinRules, "")) auth_types.add((EventTypes.Member, event.state_key)) diff --git a/synapse/events/builder.py b/synapse/events/builder.py
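The knock branch added to _is_membership_change_allowed above applies five checks in order. Restated as a standalone function (with AuthError stubbed and the join rule as a plain string) so the ordering is easy to see:

    class AuthError(Exception):
        def __init__(self, code, msg):
            super().__init__(msg)
            self.code = code


    def check_knock(join_rule, sender, target_user_id, target_in_room,
                    caller_invited, target_banned):
        if join_rule != "knock":
            raise AuthError(403, "You don't have permission to knock")
        if target_user_id != sender:
            raise AuthError(403, "You cannot knock for other users")
        if target_in_room:
            raise AuthError(403, "You cannot knock on a room you are already in")
        if caller_invited:
            raise AuthError(403, "You are already invited to this room")
        if target_banned:
            raise AuthError(403, "You are banned from this room")


    # A well-formed knock on a knockable room passes all five checks:
    check_knock("knock", "@alice:example.com", "@alice:example.com",
                False, False, False)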
index 07df258e6e..c1c0426f6e 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py
@@ -98,7 +98,9 @@ class EventBuilder: return self._state_key is not None async def build( - self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]], + self, + prev_event_ids: List[str], + auth_event_ids: Optional[List[str]], ) -> EventBase: """Transform into a fully signed and hashed event diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index afecafe15c..7295df74fe 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py
@@ -341,8 +341,7 @@ def _encode_state_dict(state_dict): def _decode_state_dict(input): - """Decodes a state dict encoded using `_encode_state_dict` above - """ + """Decodes a state dict encoded using `_encode_state_dict` above""" if input is None: return None diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index e7e3a7b9a4..063af7a81d 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py
@@ -17,6 +17,8 @@ import inspect from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from synapse.rest.media.v1._base import FileInfo +from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import Collection from synapse.util.async_helpers import maybe_awaitable @@ -63,16 +65,32 @@ class SpamChecker: return False async def user_may_invite( - self, inviter_userid: str, invitee_userid: str, room_id: str + self, + inviter_userid: str, + invitee_userid: Optional[str], + third_party_invite: Optional[Dict], + room_id: str, + new_room: bool, + published_room: bool, ) -> bool: """Checks if a given user may send an invite If this method returns false, the invite will be rejected. Args: - inviter_userid: The user ID of the sender of the invitation - invitee_userid: The user ID targeted in the invitation - room_id: The room ID + inviter_userid: + invitee_userid: The user ID of the invitee. Is None + if this is a third party invite and the 3PID is not bound to a + user ID. + third_party_invite: If a third party invite then is a + dict containing the medium and address of the invitee. + room_id: + new_room: Whether the user is being invited to the room as + part of a room creation, if so the invitee would have been + included in the call to `user_may_create_room`. + published_room: Whether the room the user is being invited + to has been published in the local homeserver's public room + directory. Returns: True if the user may send an invite, otherwise False @@ -81,7 +99,12 @@ class SpamChecker: if ( await maybe_awaitable( spam_checker.user_may_invite( - inviter_userid, invitee_userid, room_id + inviter_userid, + invitee_userid, + third_party_invite, + room_id, + new_room, + published_room, ) ) is False @@ -90,20 +113,36 @@ class SpamChecker: return True - async def user_may_create_room(self, userid: str) -> bool: + async def user_may_create_room( + self, + userid: str, + invite_list: List[str], + third_party_invite_list: List[Dict], + cloning: bool, + ) -> bool: """Checks if a given user may create a room If this method returns false, the creation request will be rejected. Args: userid: The ID of the user attempting to create a room + invite_list: List of user IDs that would be invited to + the new room. + third_party_invite_list: List of third party invites + for the new room. + cloning: Whether the user is cloning an existing room, e.g. + upgrading a room. Returns: True if the user may create a room, otherwise False """ for spam_checker in self.spam_checkers: if ( - await maybe_awaitable(spam_checker.user_may_create_room(userid)) + await maybe_awaitable( + spam_checker.user_may_create_room( + userid, invite_list, third_party_invite_list, cloning + ) + ) is False ): return False @@ -156,6 +195,25 @@ class SpamChecker: return True + def user_may_join_room(self, userid: str, room_id: str, is_invited: bool): + """Checks if a given users is allowed to join a room. + + Not called when a user creates a room. + + Args: + userid: + room_id: + is_invited: Whether the user is invited into the room + + Returns: + bool: Whether the user may join the room + """ + for spam_checker in self.spam_checkers: + if spam_checker.user_may_join_room(userid, room_id, is_invited) is False: + return False + + return True + async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool: """Checks if a user ID or display name are considered "spammy" by this server. 
@@ -214,3 +272,48 @@ class SpamChecker: return behaviour return RegistrationBehaviour.ALLOW + + async def check_media_file_for_spam( + self, file_wrapper: ReadableFileWrapper, file_info: FileInfo + ) -> bool: + """Checks if a piece of newly uploaded media should be blocked. + + This will be called for local uploads, downloads of remote media, each + thumbnail generated for those, and web pages/images used for URL + previews. + + Note that care should be taken to not do blocking IO operations in the + main thread. For example, to get the contents of a file a module + should do:: + + async def check_media_file_for_spam( + self, file: ReadableFileWrapper, file_info: FileInfo + ) -> bool: + buffer = BytesIO() + await file.write_chunks_to(buffer.write) + + if buffer.getvalue() == b"Hello World": + return True + + return False + + + Args: + file: An object that allows reading the contents of the media. + file_info: Metadata about the file. + + Returns: + True if the media should be blocked or False if it should be + allowed. + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_media_file_for_spam", None) + if checker: + spam = await maybe_awaitable(checker(file_wrapper, file_info)) + if spam: + return True + + return False diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
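What a third-party module implementing the widened callbacks above might look like; the class name and policy are illustrative, only the method signatures come from this diff:

    import asyncio
    from typing import Dict, List, Optional


    class ExampleSpamChecker:
        """Illustrative module only; signatures match the diff above."""

        def __init__(self, config):
            self._banned_inviters = config.get("banned_inviters", [])

        async def user_may_invite(
            self,
            inviter_userid: str,
            invitee_userid: Optional[str],
            third_party_invite: Optional[Dict],
            room_id: str,
            new_room: bool,
            published_room: bool,
        ) -> bool:
            # Refuse invites from a configured set of senders.
            return inviter_userid not in self._banned_inviters

        async def user_may_create_room(
            self,
            userid: str,
            invite_list: List[str],
            third_party_invite_list: List[Dict],
            cloning: bool,
        ) -> bool:
            # Always allow room upgrades; otherwise cap the initial invite list.
            return cloning or len(invite_list) <= 100


    checker = ExampleSpamChecker({"banned_inviters": ["@spam:evil.example"]})
    print(asyncio.run(checker.user_may_invite(
        "@spam:evil.example", "@bob:example.com", None, "!room:example.com",
        False, True,
    )))  # False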
index 77fbd3f68a..02bce8b5c9 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py
@@ -40,7 +40,8 @@ class ThirdPartyEventRules: if module is not None: self.third_party_rules = module( - config=config, module_api=hs.get_module_api(), + config=config, + module_api=hs.get_module_api(), ) async def check_event_allowed( diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9c22e33813..e27af490c2 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py
@@ -34,7 +34,7 @@ SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.") def prune_event(event: EventBase) -> EventBase: - """ Returns a pruned version of the given event, which removes all keys we + """Returns a pruned version of the given event, which removes all keys we don't know about or think could potentially be dodgy. This is used when we "redact" an event. We want to remove all fields that @@ -240,6 +240,7 @@ def format_event_for_client_v1(d): "replaces_state", "prev_content", "invite_room_state", + "knock_room_state", ) for key in copy_keys: if key in d["unsigned"]: @@ -276,7 +277,7 @@ def serialize_event( event_format=format_event_for_client_v1, token_id=None, only_event_fields=None, - is_invite=False, + include_stripped_room_state=False, ): """Serialize event for clients @@ -287,8 +288,10 @@ def serialize_event( event_format token_id only_event_fields - is_invite (bool): Whether this is an invite that is being sent to the - invitee + include_stripped_room_state (bool): Some events can have stripped room state + stored in the `unsigned` field. This is required for invite and knock + functionality. If this option is False, that state will be removed from the + event before it is returned. Otherwise, it will be kept. Returns: dict @@ -320,11 +323,13 @@ def serialize_event( if txn_id is not None: d["unsigned"]["transaction_id"] = txn_id - # If this is an invite for somebody else, then we don't care about the - # invite_room_state as that's meant solely for the invitee. Other clients - # will already have the state since they're in the room. - if not is_invite: + # invite_room_state and knock_room_state are a list of stripped room state events + # that are meant to provide metadata about a room to an invitee/knocker. They are + # intended to only be included in specific circumstances, such as down sync, and + # should not be included in any other case. + if not include_stripped_room_state: d["unsigned"].pop("invite_room_state", None) + d["unsigned"].pop("knock_room_state", None) if as_client_event: d = event_format(d) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
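An illustration of the unsigned-key stripping performed by serialize_event above: unless include_stripped_room_state is set, both invite_room_state and knock_room_state are dropped before the event reaches a client:

    event = {
        "type": "m.room.member",
        "unsigned": {"knock_room_state": [], "age": 1234},
    }

    include_stripped_room_state = False
    if not include_stripped_room_state:
        # Same two pops as in serialize_event above.
        event["unsigned"].pop("invite_room_state", None)
        event["unsigned"].pop("knock_room_state", None)

    print(event["unsigned"])  # {'age': 1234}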
index 40e1451201..6b1f7dcb89 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyrignt 2020 Sorunome +# Copyrignt 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -553,7 +555,7 @@ class FederationClient(FederationBase): RuntimeError: if no servers were reachable. """ - valid_memberships = {Membership.JOIN, Membership.LEAVE} + valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" @@ -750,7 +752,11 @@ class FederationClient(FederationBase): return resp[1] async def send_invite( - self, destination: str, room_id: str, event_id: str, pdu: EventBase, + self, + destination: str, + room_id: str, + event_id: str, + pdu: EventBase, ) -> EventBase: room_version = await self.store.get_room_version(room_id) @@ -888,6 +894,62 @@ class FederationClient(FederationBase): # content. return resp[1] + async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict: + """Attempts to send a knock event to given a list of servers. Iterates + through the list until one attempt succeeds. + + Doing so will cause the remote server to add the event to the graph, + and send the event out to the rest of the federation. + + Args: + destinations: A list of candidate homeservers which are likely to be + participating in the room. + pdu: The event to be sent. + + Returns: + The remote homeserver return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + + Raises: + SynapseError: If the chosen remote server returns a 3xx/4xx code. + RuntimeError: If no servers were reachable. + """ + + async def send_request(destination: str) -> JsonDict: + return await self._do_send_knock(destination, pdu) + + return await self._try_destination_list( + "xyz.amorgan.knock/send_knock", destinations, send_request + ) + + async def _do_send_knock(self, destination: str, pdu: EventBase) -> JsonDict: + """Send a knock event to a remote homeserver. + + Args: + destination: The homeserver to send to. + pdu: The event to send. + + Returns: + The remote homeserver can optionally return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + """ + time_now = self._clock.time_msec() + + return await self.transport_layer.send_knock_v2( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + async def get_public_rooms( self, remote_server: str, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 171d25c945..6337accc70 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -45,6 +45,7 @@ from synapse.api.errors import ( UnsupportedRoomVersionError, ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.config.api import DEFAULT_ROOM_STATE_TYPES from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions @@ -85,7 +86,8 @@ received_queries_counter = Counter( ) pdu_process_time = Histogram( - "synapse_federation_server_pdu_process_time", "Time taken to process an event", + "synapse_federation_server_pdu_process_time", + "Time taken to process an event", ) @@ -204,7 +206,7 @@ class FederationServer(FederationBase): async def _handle_incoming_transaction( self, origin: str, transaction: Transaction, request_time: int ) -> Tuple[int, Dict[str, Any]]: - """ Process an incoming transaction and return the HTTP response + """Process an incoming transaction and return the HTTP response Args: origin: the server making the request @@ -373,8 +375,7 @@ class FederationServer(FederationBase): return pdu_results async def _handle_edus_in_txn(self, origin: str, transaction: Transaction): - """Process the EDUs in a received transaction. - """ + """Process the EDUs in a received transaction.""" async def _process_edu(edu_dict): received_edus_counter.inc() @@ -437,7 +438,10 @@ class FederationServer(FederationBase): raise AuthError(403, "Host not in room.") resp = await self._state_ids_resp_cache.wrap( - (room_id, event_id), self._on_state_ids_request_compute, room_id, event_id, + (room_id, event_id), + self._on_state_ids_request_compute, + room_id, + event_id, ) return 200, resp @@ -567,6 +571,76 @@ class FederationServer(FederationBase): await self.handler.on_send_leave_request(origin, pdu) return {} + async def on_make_knock_request( + self, origin: str, room_id: str, user_id: str, supported_versions: List[str] + ) -> Dict[str, Union[EventBase, str]]: + """We've received a /make_knock/ request, so we create a partial knock + event for the room and hand that back, along with the room version, to the knocking + homeserver. We do *not* persist or process this event until the other server has + signed it and sent it back. + + Args: + origin: The (verified) server name of the requesting server. + room_id: The room to create the knock event in. + user_id: The user to create the knock for. + supported_versions: The room versions supported by the requesting server. + + Returns: + The partial knock event. + """ + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, room_id) + + room_version = await self.store.get_room_version_id(room_id) + if room_version not in supported_versions: + logger.warning( + "Room version %s not in %s", room_version, supported_versions + ) + raise IncompatibleRoomVersionError(room_version=room_version) + + pdu = await self.handler.on_make_knock_request(origin, room_id, user_id) + time_now = self._clock.time_msec() + return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + + async def on_send_knock_request( + self, origin: str, content: JsonDict, room_id: str, + ) -> Dict[str, List[JsonDict]]: + """ + We have received a knock event for a room. Verify and send the event into the room + on the knocking homeserver's behalf. Then reply with some stripped state from the + room for the knockee. + + Args: + origin: The remote homeserver of the knocking user. + content: The content of the request. + room_id: The ID of the room to knock on. 
+ + Returns: + The stripped room state. + """ + logger.debug("on_send_knock_request: content: %s", content) + + room_version = await self.store.get_room_version(room_id) + pdu = event_from_pdu_json(content, room_version) + + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, pdu.room_id) + + logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures) + + pdu = await self._check_sigs_and_hash(room_version, pdu) + + # Handle the event, and retrieve the EventContext + event_context = await self.handler.on_send_knock_request(origin, pdu) + + # Retrieve stripped state events from the room and send them back to the remote + # server. This will allow the remote server's clients to display information + # related to the room while the knock request is pending. + stripped_room_state = await self.store.get_stripped_room_state_from_event_context( + event_context, DEFAULT_ROOM_STATE_TYPES + ) + return {"knock_state_events": stripped_room_state} + async def on_event_auth( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: @@ -679,7 +753,7 @@ class FederationServer(FederationBase): ) async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None: - """ Process a PDU received in a federation /send/ transaction. + """Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. (The error will then be logged and sent back to the sender (which @@ -906,13 +980,11 @@ class FederationHandlerRegistry: self.query_handlers[query_type] = handler def register_instance_for_edu(self, edu_type: str, instance_name: str): - """Register that the EDU handler is on a different instance than master. - """ + """Register that the EDU handler is on a different instance than master.""" self._edu_type_to_instance[edu_type] = [instance_name] def register_instances_for_edu(self, edu_type: str, instance_names: List[str]): - """Register that the EDU handler is on multiple instances. - """ + """Register that the EDU handler is on multiple instances.""" self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict): diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
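The version negotiation in on_make_knock_request above, in isolation: the server only hands out a partial knock event if the requesting server supports the room's version (the exception class is stubbed here):

    class IncompatibleRoomVersionError(Exception):
        def __init__(self, room_version):
            super().__init__("Room version %s unsupported" % (room_version,))


    def check_room_version(room_version, supported_versions):
        if room_version not in supported_versions:
            raise IncompatibleRoomVersionError(room_version)


    check_room_version("6", ["5", "6"])  # fine
    try:
        check_room_version("7", ["5", "6"])
    except IncompatibleRoomVersionError as e:
        print(e)  # Room version 7 unsupported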
index 079e2b2fe0..ce5fc758f0 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py
@@ -30,8 +30,7 @@ logger = logging.getLogger(__name__) class TransactionActions: - """ Defines persistence actions that relate to handling Transactions. - """ + """Defines persistence actions that relate to handling Transactions.""" def __init__(self, datastore): self.store = datastore @@ -57,8 +56,7 @@ class TransactionActions: async def set_response( self, origin: str, transaction: Transaction, code: int, response: JsonDict ) -> None: - """Persist how we responded to a transaction. - """ + """Persist how we responded to a transaction.""" transaction_id = transaction.transaction_id # type: ignore if not transaction_id: raise RuntimeError("Cannot persist a transaction with no transaction_id") diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5f1bf492c1..3e993b428b 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py
@@ -468,8 +468,7 @@ class KeyedEduRow( class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu - """Streams EDUs that don't have keys. See KeyedEduRow - """ + """Streams EDUs that don't have keys. See KeyedEduRow""" TypeId = "e" @@ -519,7 +518,10 @@ def process_rows_for_federation(transaction_queue, rows): # them into the appropriate collection and then send them off. buff = ParsedFederationStreamData( - presence=[], presence_destinations=[], keyed_edus={}, edus={}, + presence=[], + presence_destinations=[], + keyed_edus={}, + edus={}, ) # Parse the rows in the stream and add to the buffer diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 643b26ae6d..97fc4d0a82 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py
@@ -328,7 +328,9 @@ class FederationSender: # to allow us to perform catch-up later on if the remote is unreachable # for a while. await self.store.store_destination_rooms_entries( - destinations, pdu.room_id, pdu.internal_metadata.stream_ordering, + destinations, + pdu.room_id, + pdu.internal_metadata.stream_ordering, ) for destination in destinations: @@ -475,7 +477,7 @@ class FederationSender: self, states: List[UserPresenceState], destinations: List[str] ) -> None: """Send the given presence states to the given destinations. - destinations (list[str]) + destinations (list[str]) """ if not states or not self.hs.config.use_presence: @@ -616,8 +618,8 @@ class FederationSender: last_processed = None # type: Optional[str] while True: - destinations_to_wake = await self.store.get_catch_up_outstanding_destinations( - last_processed + destinations_to_wake = ( + await self.store.get_catch_up_outstanding_destinations(last_processed) ) if not destinations_to_wake: diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index db8e456fe8..deb519f3ef 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py
@@ -85,7 +85,8 @@ class PerDestinationQueue: # processing. We have a guard in `attempt_new_transaction` that # ensure we don't start sending stuff. logger.error( - "Create a per destination queue for %s on wrong worker", destination, + "Create a per destination queue for %s on wrong worker", + destination, ) self._should_send_on_this_instance = False @@ -440,8 +441,10 @@ class PerDestinationQueue: if first_catch_up_check: # first catchup so get last_successful_stream_ordering from database - self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering( - self._destination + self._last_successful_stream_ordering = ( + await self._store.get_destination_last_successful_stream_ordering( + self._destination + ) ) if self._last_successful_stream_ordering is None: @@ -457,7 +460,8 @@ class PerDestinationQueue: # get at most 50 catchup room/PDUs while True: event_ids = await self._store.get_catch_up_room_event_ids( - self._destination, self._last_successful_stream_ordering, + self._destination, + self._last_successful_stream_ordering, ) if not event_ids: diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3e07f925e0..763aff296c 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py
@@ -65,7 +65,10 @@ class TransactionManager: @measure_func("_send_new_transaction") async def send_new_transaction( - self, destination: str, pdus: List[EventBase], edus: List[Edu], + self, + destination: str, + pdus: List[EventBase], + edus: List[Edu], ) -> bool: """ Args: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index abe9168c78..68d9349a30 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py
@@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd +# Copyright 2020 Sorunome +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +18,7 @@ import logging import urllib -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError @@ -26,6 +28,7 @@ from synapse.api.urls import ( FEDERATION_V2_PREFIX, ) from synapse.logging.utils import log_function +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -39,7 +42,7 @@ class TransportLayerClient: @log_function def get_room_state_ids(self, destination, room_id, event_id): - """ Requests all state for a given room from the given server at the + """Requests all state for a given room from the given server at the given event. Returns the state's event_id's Args: @@ -63,7 +66,7 @@ class TransportLayerClient: @log_function def get_event(self, destination, event_id, timeout=None): - """ Requests the pdu with give id and origin from the given server. + """Requests the pdu with give id and origin from the given server. Args: destination (str): The host name of the remote homeserver we want @@ -84,7 +87,7 @@ class TransportLayerClient: @log_function def backfill(self, destination, room_id, event_tuples, limit): - """ Requests `limit` previous PDUs in a given context before list of + """Requests `limit` previous PDUs in a given context before list of PDUs. Args: @@ -118,7 +121,7 @@ class TransportLayerClient: @log_function async def send_transaction(self, transaction, json_data_callback=None): - """ Sends the given Transaction to its destination + """Sends the given Transaction to its destination Args: transaction (Transaction) @@ -209,13 +212,24 @@ class TransportLayerClient: Fails with ``FederationDeniedError`` if the remote destination is not in our federation whitelist """ - valid_memberships = {Membership.JOIN, Membership.LEAVE} + valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK} if membership not in valid_memberships: raise RuntimeError( "make_membership_event called with membership='%s', must be one of %s" % (membership, ",".join(valid_memberships)) ) - path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) + + # Knock currently uses an unstable prefix + if membership == Membership.KNOCK: + # Create a path in the form of /unstable/xyz.amorgan.knock/make_knock/... + path = _create_path( + FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock", + "/make_knock/%s/%s", + room_id, + user_id, + ) + else: + path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) ignore_backoff = False retry_on_dns_fail = False @@ -294,6 +308,41 @@ class TransportLayerClient: return response @log_function + async def send_knock_v2( + self, destination: str, room_id: str, event_id: str, content: JsonDict, + ) -> JsonDict: + """ + Sends a signed knock membership event to a remote server. This is the second + step for knocking after make_knock. + + Args: + destination: The remote homeserver. + room_id: The ID of the room to knock on. + event_id: The ID of the knock membership event that we're sending. + content: The knock membership event that we're sending. 
Note that this is not the + `content` field of the membership event, but the entire signed membership event + itself represented as a JSON dict. + + Returns: + The remote homeserver can optionally return some state from the room. The response + dictionary is in the form: + + {"knock_state_events": [<state event dict>, ...]} + + The list of state events may be empty. + """ + path = _create_path( + FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock", + "/send_knock/%s/%s", + room_id, + event_id, + ) + + return await self.client.put_json( + destination=destination, path=path, data=content + ) + + @log_function async def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) @@ -551,8 +600,7 @@ class TransportLayerClient: @log_function def get_group_profile(self, destination, group_id, requester_user_id): - """Get a group profile - """ + """Get a group profile""" path = _create_v1_path("/groups/%s/profile", group_id) return self.client.get_json( @@ -584,8 +632,7 @@ class TransportLayerClient: @log_function def get_group_summary(self, destination, group_id, requester_user_id): - """Get a group summary - """ + """Get a group summary""" path = _create_v1_path("/groups/%s/summary", group_id) return self.client.get_json( @@ -597,8 +644,7 @@ class TransportLayerClient: @log_function def get_rooms_in_group(self, destination, group_id, requester_user_id): - """Get all rooms in a group - """ + """Get all rooms in a group""" path = _create_v1_path("/groups/%s/rooms", group_id) return self.client.get_json( @@ -611,8 +657,7 @@ class TransportLayerClient: def add_room_to_group( self, destination, group_id, requester_user_id, room_id, content ): - """Add a room to a group - """ + """Add a room to a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.post_json( @@ -626,8 +671,7 @@ class TransportLayerClient: def update_room_in_group( self, destination, group_id, requester_user_id, room_id, config_key, content ): - """Update room in group - """ + """Update room in group""" path = _create_v1_path( "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) @@ -641,8 +685,7 @@ class TransportLayerClient: ) def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): - """Remove a room from a group - """ + """Remove a room from a group""" path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.delete_json( @@ -654,8 +697,7 @@ class TransportLayerClient: @log_function def get_users_in_group(self, destination, group_id, requester_user_id): - """Get users in a group - """ + """Get users in a group""" path = _create_v1_path("/groups/%s/users", group_id) return self.client.get_json( @@ -667,8 +709,7 @@ class TransportLayerClient: @log_function def get_invited_users_in_group(self, destination, group_id, requester_user_id): - """Get users that have been invited to a group - """ + """Get users that have been invited to a group""" path = _create_v1_path("/groups/%s/invited_users", group_id) return self.client.get_json( @@ -680,8 +721,7 @@ class TransportLayerClient: @log_function def accept_group_invite(self, destination, group_id, user_id, content): - """Accept a group invite - """ + """Accept a group invite""" path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) return self.client.post_json( @@ -690,8 +730,7 @@ class TransportLayerClient: @log_function def join_group(self, destination, group_id, user_id, content): - """Attempts to join 
a group - """ + """Attempts to join a group""" path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( @@ -702,8 +741,7 @@ class TransportLayerClient: def invite_to_group( self, destination, group_id, user_id, requester_user_id, content ): - """Invite a user to a group - """ + """Invite a user to a group""" path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) return self.client.post_json( @@ -730,8 +768,7 @@ class TransportLayerClient: def remove_user_from_group( self, destination, group_id, requester_user_id, user_id, content ): - """Remove a user from a group - """ + """Remove a user from a group""" path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) return self.client.post_json( @@ -772,8 +809,7 @@ class TransportLayerClient: def update_group_summary_room( self, destination, group_id, user_id, room_id, category_id, content ): - """Update a room entry in a group summary - """ + """Update a room entry in a group summary""" if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", @@ -796,8 +832,7 @@ class TransportLayerClient: def delete_group_summary_room( self, destination, group_id, user_id, room_id, category_id ): - """Delete a room entry in a group summary - """ + """Delete a room entry in a group summary""" if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", @@ -817,8 +852,7 @@ class TransportLayerClient: @log_function def get_group_categories(self, destination, group_id, requester_user_id): - """Get all categories in a group - """ + """Get all categories in a group""" path = _create_v1_path("/groups/%s/categories", group_id) return self.client.get_json( @@ -830,8 +864,7 @@ class TransportLayerClient: @log_function def get_group_category(self, destination, group_id, requester_user_id, category_id): - """Get category info in a group - """ + """Get category info in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.get_json( @@ -845,8 +878,7 @@ class TransportLayerClient: def update_group_category( self, destination, group_id, requester_user_id, category_id, content ): - """Update a category in a group - """ + """Update a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.post_json( @@ -861,8 +893,7 @@ class TransportLayerClient: def delete_group_category( self, destination, group_id, requester_user_id, category_id ): - """Delete a category in a group - """ + """Delete a category in a group""" path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.delete_json( @@ -874,8 +905,7 @@ class TransportLayerClient: @log_function def get_group_roles(self, destination, group_id, requester_user_id): - """Get all roles in a group - """ + """Get all roles in a group""" path = _create_v1_path("/groups/%s/roles", group_id) return self.client.get_json( @@ -887,8 +917,7 @@ class TransportLayerClient: @log_function def get_group_role(self, destination, group_id, requester_user_id, role_id): - """Get a roles info - """ + """Get a roles info""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.get_json( @@ -902,8 +931,7 @@ class TransportLayerClient: def update_group_role( self, destination, group_id, requester_user_id, role_id, content ): - """Update a role in a group - """ + """Update a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return 
self.client.post_json( @@ -916,8 +944,7 @@ class TransportLayerClient: @log_function def delete_group_role(self, destination, group_id, requester_user_id, role_id): - """Delete a role in a group - """ + """Delete a role in a group""" path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.delete_json( @@ -931,8 +958,7 @@ class TransportLayerClient: def update_group_summary_user( self, destination, group_id, requester_user_id, user_id, role_id, content ): - """Update a users entry in a group - """ + """Update a user's entry in a group""" if role_id: path = _create_v1_path( "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id @@ -950,8 +976,7 @@ class TransportLayerClient: @log_function def set_group_join_policy(self, destination, group_id, requester_user_id, content): - """Sets the join policy for a group - """ + """Sets the join policy for a group""" path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) return self.client.put_json( @@ -966,8 +991,7 @@ class TransportLayerClient: def delete_group_summary_user( self, destination, group_id, requester_user_id, user_id, role_id ): - """Delete a users entry in a group - """ + """Delete a user's entry in a group""" if role_id: path = _create_v1_path( "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id @@ -983,8 +1007,7 @@ class TransportLayerClient: ) def bulk_get_publicised_groups(self, destination, user_ids): - """Get the groups a list of users are publicising - """ + """Get the groups a list of users are publicising""" path = _create_v1_path("/get_groups_publicised") @@ -1004,6 +1027,20 @@ class TransportLayerClient: return self.client.get_json(destination=destination, path=path) + def get_info_of_users(self, destination: str, user_ids: List[str]): + """ + Args: + destination: The remote server + user_ids: A list of user IDs to query info about + + Returns: + Deferred[dict]: A dict mapping each queried user ID to information about that user. + """ + path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/users/info") + data = {"user_ids": user_ids} + + return self.client.post_json(destination=destination, path=path, data=data) + def _create_path(federation_prefix, path, *args): """ diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
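For orientation, a minimal sketch of the request body that the new get_info_of_users method sends to the unstable /users/info federation endpoint, and how a caller might consume the response. The helper names here are this sketch's own inventions; only the wire shapes come from the diff.

from typing import Dict, List


def build_users_info_request(user_ids: List[str]) -> dict:
    # Body for POST <unstable prefix>/users/info on the remote server.
    return {"user_ids": user_ids}


def usable_users(response: Dict[str, dict]) -> List[str]:
    # The response maps each user ID to {"expired": bool,
    # "deactivated": bool}; keep only accounts that are neither.
    return [
        user_id
        for user_id, info in response.items()
        if not info.get("expired") and not info.get("deactivated")
    ]


resp = {"@alice:example.com": {"expired": False, "deactivated": True}}
assert build_users_info_request(["@alice:example.com"]) == {
    "user_ids": ["@alice:example.com"]
}
assert usable_users(resp) == []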
index 95c64510a9..8d89074c5a 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py
@@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,13 +15,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import functools import logging import re from typing import Optional, Tuple, Type import synapse +from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH from synapse.api.errors import Codes, FederationDeniedError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.api.urls import ( @@ -30,9 +31,11 @@ from synapse.api.urls import ( ) from synapse.http.server import JsonResource from synapse.http.servlet import ( + assert_params_in_dict, parse_boolean_from_args, parse_integer_from_args, parse_json_object_from_request, + parse_list_from_args, parse_string_from_args, ) from synapse.logging.context import run_in_background @@ -364,7 +367,10 @@ class BaseFederationServlet: continue server.register_paths( - method, (pattern,), self._wrap(code), self.__class__.__name__, + method, + (pattern,), + self._wrap(code), + self.__class__.__name__, ) @@ -381,7 +387,7 @@ class FederationSendServlet(BaseFederationServlet): # This is when someone is trying to send us a bunch of data. async def on_PUT(self, origin, content, query, transaction_id): - """ Called on PUT /send/<transaction_id>/ + """Called on PUT /send/<transaction_id>/ Args: request (twisted.web.http.Request): The HTTP request. @@ -542,6 +548,34 @@ class FederationV2SendLeaveServlet(BaseFederationServlet): return 200, content +class FederationMakeKnockServlet(BaseFederationServlet): + PATH = "/make_knock/(?P<room_id>[^/]*)/(?P<user_id>[^/]*)" + + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock" + + async def on_GET(self, origin, content, query, room_id, user_id): + try: + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_list_from_args(query, "ver", encoding="utf-8") + except KeyError: + raise SynapseError(400, "Missing required query parameter 'ver'") + + content = await self.handler.on_make_knock_request( + origin, room_id, user_id, supported_versions=supported_versions + ) + return 200, content + + +class FederationV2SendKnockServlet(BaseFederationServlet): + PATH = "/send_knock/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_UNSTABLE_PREFIX + "/xyz.amorgan.knock" + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, content + + class FederationEventAuthServlet(BaseFederationServlet): PATH = "/event_auth/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" @@ -842,6 +876,57 @@ class PublicRoomList(BaseFederationServlet): return 200, data +class FederationUserInfoServlet(BaseFederationServlet): + """ + Return information about a set of users. + + This API returns expiration and deactivation information about a set of + users. Requested users not local to this homeserver will be ignored. 
+ + Example request: + POST /users/info + + { + "user_ids": [ + "@alice:example.com", + "@bob:example.com" + ] + } + + Example response + { + "@alice:example.com": { + "expired": false, + "deactivated": true + } + } + """ + + PATH = "/users/info" + PREFIX = FEDERATION_UNSTABLE_PREFIX + + def __init__(self, handler, authenticator, ratelimiter, server_name): + super(FederationUserInfoServlet, self).__init__( + handler, authenticator, ratelimiter, server_name + ) + self.handler = handler + + async def on_POST(self, origin, content, query): + assert_params_in_dict(content, required=["user_ids"]) + + user_ids = content.get("user_ids", []) + + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + data = await self.handler.store.get_info_for_users(user_ids) + return 200, data + + class FederationVersionServlet(BaseFederationServlet): PATH = "/version" @@ -855,8 +940,7 @@ class FederationVersionServlet(BaseFederationServlet): class FederationGroupsProfileServlet(BaseFederationServlet): - """Get/set the basic profile of a group on behalf of a user - """ + """Get/set the basic profile of a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/profile" @@ -895,8 +979,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): class FederationGroupsRoomsServlet(BaseFederationServlet): - """Get the rooms in a group on behalf of a user - """ + """Get the rooms in a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/rooms" @@ -911,8 +994,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsServlet(BaseFederationServlet): - """Add/remove room from group - """ + """Add/remove room from group""" PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" @@ -940,8 +1022,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): - """Update room config in group - """ + """Update room config in group""" PATH = ( "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" @@ -961,8 +1042,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): class FederationGroupsUsersServlet(BaseFederationServlet): - """Get the users in a group on behalf of a user - """ + """Get the users in a group on behalf of a user""" PATH = "/groups/(?P<group_id>[^/]*)/users" @@ -977,8 +1057,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): class FederationGroupsInvitedUsersServlet(BaseFederationServlet): - """Get the users that have been invited to a group - """ + """Get the users that have been invited to a group""" PATH = "/groups/(?P<group_id>[^/]*)/invited_users" @@ -995,8 +1074,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): class FederationGroupsInviteServlet(BaseFederationServlet): - """Ask a group server to invite someone to the group - """ + """Ask a group server to invite someone to the group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @@ -1013,8 +1091,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): class FederationGroupsAcceptInviteServlet(BaseFederationServlet): - """Accept an invitation from the group server - """ + """Accept an invitation from the group server""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" @@ -1028,8 +1105,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): class FederationGroupsJoinServlet(BaseFederationServlet): - 
"""Attempt to join a group - """ + """Attempt to join a group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" @@ -1043,8 +1119,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): class FederationGroupsRemoveUserServlet(BaseFederationServlet): - """Leave or kick a user from the group - """ + """Leave or kick a user from the group""" PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @@ -1061,8 +1136,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): class FederationGroupsLocalInviteServlet(BaseFederationServlet): - """A group server has invited a local user - """ + """A group server has invited a local user""" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @@ -1076,8 +1150,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): - """A group server has removed a local user - """ + """A group server has removed a local user""" PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @@ -1093,8 +1166,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): - """A group or user's server renews their attestation - """ + """A group or user's server renews their attestation""" PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" @@ -1128,7 +1200,17 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") + raise SynapseError( + 400, "category_id cannot be empty string", Codes.INVALID_PARAM + ) + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) resp = await self.handler.update_group_summary_room( group_id, @@ -1156,8 +1238,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): class FederationGroupsCategoriesServlet(BaseFederationServlet): - """Get all categories for a group - """ + """Get all categories for a group""" PATH = "/groups/(?P<group_id>[^/]*)/categories/?" @@ -1172,8 +1253,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): class FederationGroupsCategoryServlet(BaseFederationServlet): - """Add/remove/get a category in a group - """ + """Add/remove/get a category in a group""" PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" @@ -1196,6 +1276,14 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): if category_id == "": raise SynapseError(400, "category_id cannot be empty string") + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + resp = await self.handler.upsert_group_category( group_id, requester_user_id, category_id, content ) @@ -1218,8 +1306,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): class FederationGroupsRolesServlet(BaseFederationServlet): - """Get roles in a group - """ + """Get roles in a group""" PATH = "/groups/(?P<group_id>[^/]*)/roles/?" 
@@ -1234,8 +1321,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): class FederationGroupsRoleServlet(BaseFederationServlet): - """Add/remove/get a role in a group - """ + """Add/remove/get a role in a group""" PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" @@ -1254,7 +1340,17 @@ class FederationGroupsRoleServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") + raise SynapseError( + 400, "role_id cannot be empty string", Codes.INVALID_PARAM + ) + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) resp = await self.handler.update_group_role( group_id, requester_user_id, role_id, content @@ -1299,6 +1395,14 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): if role_id == "": raise SynapseError(400, "role_id cannot be empty string") + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + resp = await self.handler.update_group_summary_user( group_id, requester_user_id, @@ -1325,8 +1429,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): - """Get roles in a group - """ + """Get roles in a group""" PATH = "/get_groups_publicised" @@ -1339,8 +1442,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): - """Sets whether a group is joinable without an invite or knock - """ + """Sets whether a group is joinable without an invite or knock""" PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" @@ -1387,11 +1489,13 @@ FEDERATION_SERVLET_CLASSES = ( FederationQueryServlet, FederationMakeJoinServlet, FederationMakeLeaveServlet, + FederationMakeKnockServlet, FederationEventServlet, FederationV1SendJoinServlet, FederationV2SendJoinServlet, FederationV1SendLeaveServlet, FederationV2SendLeaveServlet, + FederationV2SendKnockServlet, FederationV1InviteServlet, FederationV2InviteServlet, FederationGetMissingEventsServlet, @@ -1403,6 +1507,7 @@ FEDERATION_SERVLET_CLASSES = ( On3pidBindServlet, FederationVersionServlet, RoomComplexityServlet, + FederationUserInfoServlet, ) # type: Tuple[Type[BaseFederationServlet], ...] OPENID_SERVLET_CLASSES = ( diff --git a/synapse/federation/units.py b/synapse/federation/units.py
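FederationMakeKnockServlet above reads the remote server's supported room versions from repeated "ver" query parameters. A rough standalone illustration of that parsing; parse_list_from_args is Synapse's own helper, approximated here with the standard library.

from typing import List
from urllib.parse import parse_qs


def supported_versions_from_query(query_string: str) -> List[str]:
    # GET /make_knock/{roomId}/{userId}?ver=1&ver=6 repeats "ver" once
    # per supported room version; a missing "ver" becomes a 400 upstream.
    args = parse_qs(query_string)
    if "ver" not in args:
        raise KeyError("ver")
    return args["ver"]


assert supported_versions_from_query("ver=1&ver=6") == ["1", "6"]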
index 64d98fc8f6..b662c42621 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True) class Edu(JsonEncodedObject): - """ An Edu represents a piece of data sent from one homeserver to another. + """An Edu represents a piece of data sent from one homeserver to another. In comparison to Pdus, Edus are not persisted for a long time on disk, are not meaningful beyond a given pair of homeservers, and don't have an @@ -63,7 +63,7 @@ class Edu(JsonEncodedObject): class Transaction(JsonEncodedObject): - """ A transaction is a list of Pdus and Edus to be sent to a remote home + """A transaction is a list of Pdus and Edus to be sent to a remote home server with some extra metadata. Example transaction:: @@ -99,7 +99,7 @@ class Transaction(JsonEncodedObject): ] def __init__(self, transaction_id=None, pdus=[], **kwargs): - """ If we include a list of pdus then we decode then as PDU's + """If we include a list of pdus then we decode them as PDUs automatically. """ @@ -111,7 +111,7 @@ class Transaction(JsonEncodedObject): @staticmethod def create_new(pdus, **kwargs): - """ Used to create a new transaction. Will auto fill out + """Used to create a new transaction. Will auto fill out transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
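For reference, the Transaction docstring above mentions an example transaction; a transaction body has roughly this shape. The values are invented for illustration, and the field names follow the Matrix federation transaction format rather than anything new in this diff.

# Illustrative only: "pdus" carry persisted room events; "edus" carry
# ephemeral data such as typing notifications and read receipts.
transaction = {
    "transaction_id": "916d630ea616342b42e98a3be0b74113",
    "origin": "origin.example.com",
    "origin_server_ts": 1404835423000,
    "pdus": [],
    "edus": [],
}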
index 41cf07cc88..a3f8d92d08 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py
@@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like: import logging import random -from typing import Tuple +from typing import TYPE_CHECKING, Optional, Tuple from signedjson.sign import sign_json from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import get_domain_from_id +from synapse.types import JsonDict, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -61,18 +64,21 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 class GroupAttestationSigning: - """Creates and verifies group attestations. - """ + """Creates and verifies group attestations.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.keyring = hs.get_keyring() self.clock = hs.get_clock() self.server_name = hs.hostname self.signing_key = hs.signing_key async def verify_attestation( - self, attestation, group_id, user_id, server_name=None - ): + self, + attestation: JsonDict, + group_id: str, + user_id: str, + server_name: Optional[str] = None, + ) -> None: """Verifies that the given attestation matches the given parameters. An optional server_name can be supplied to explicitly set which server's @@ -101,16 +107,18 @@ class GroupAttestationSigning: if valid_until_ms < now: raise SynapseError(400, "Attestation expired") + assert server_name is not None await self.keyring.verify_json_for_server( server_name, attestation, now, "Group attestation" ) - def create_attestation(self, group_id, user_id): + def create_attestation(self, group_id: str, user_id: str) -> JsonDict: """Create an attestation for the group_id and user_id with default validity length. """ - validity_period = DEFAULT_ATTESTATION_LENGTH_MS - validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER) + validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform( + *DEFAULT_ATTESTATION_JITTER + ) valid_until_ms = int(self.clock.time_msec() + validity_period) return sign_json( @@ -125,10 +133,9 @@ class GroupAttestationSigning: class GroupAttestionRenewer: - """Responsible for sending and receiving attestation updates. 
- """ + """Responsible for sending and receiving attestation updates.""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastore() self.assestations = hs.get_groups_attestation_signing() @@ -141,9 +148,10 @@ class GroupAttestionRenewer: self._start_renew_attestations, 30 * 60 * 1000 ) - async def on_renew_attestation(self, group_id, user_id, content): - """When a remote updates an attestation - """ + async def on_renew_attestation( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: + """When a remote updates an attestation""" attestation = content["attestation"] if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): @@ -157,12 +165,11 @@ class GroupAttestionRenewer: return {} - def _start_renew_attestations(self): + def _start_renew_attestations(self) -> None: return run_as_background_process("renew_attestations", self._renew_attestations) - async def _renew_attestations(self): - """Called periodically to check if we need to update any of our attestations - """ + async def _renew_attestations(self) -> None: + """Called periodically to check if we need to update any of our attestations""" now = self.clock.time_msec() @@ -170,7 +177,7 @@ class GroupAttestionRenewer: now + UPDATE_ATTESTATION_TIME_MS ) - async def _renew_attestation(group_user: Tuple[str, str]): + async def _renew_attestation(group_user: Tuple[str, str]) -> None: group_id, user_id = group_user try: if not self.is_mine_id(group_id): diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 0d042cbfac..f9a0f40221 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py
@@ -16,11 +16,17 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, SynapseError -from synapse.types import GroupID, RoomID, UserID, get_domain_from_id +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN +from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -32,8 +38,13 @@ logger = logging.getLogger(__name__) # TODO: Flairs +# Note that the maximum lengths are somewhat arbitrary. +MAX_SHORT_DESC_LEN = 1000 +MAX_LONG_DESC_LEN = 10000 + + class GroupsServerWorkerHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.room_list_handler = hs.get_room_list_handler() @@ -48,16 +59,21 @@ class GroupsServerWorkerHandler: self.profile_handler = hs.get_profile_handler() async def check_group_is_ours( - self, group_id, requester_user_id, and_exists=False, and_is_admin=None - ): + self, + group_id: str, + requester_user_id: str, + and_exists: bool = False, + and_is_admin: Optional[str] = None, + ) -> Optional[dict]: """Check that the group is ours, and optionally if it exists. If group does exist then return group. Args: - group_id (str) - and_exists (bool): whether to also check if group exists - and_is_admin (str): whether to also check if given str is a user_id + group_id: The group ID to check. + requester_user_id: The user ID of the requester. + and_exists: whether to also check if group exists + and_is_admin: whether to also check if given str is a user_id that is an admin """ if not self.is_mine_id(group_id): @@ -80,7 +96,9 @@ class GroupsServerWorkerHandler: return group - async def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the summary for a group as seen by requester_user_id. 
The group summary consists of the profile of the room, and a curated @@ -113,6 +131,8 @@ class GroupsServerWorkerHandler: entry = await self.room_list_handler.generate_room_entry( room_id, len(joined_users), with_alias=False, allow_private=True ) + if entry is None: + continue entry = dict(entry) # so we don't change what's cached entry.pop("room_id", None) @@ -120,22 +140,22 @@ class GroupsServerWorkerHandler: rooms.sort(key=lambda e: e.get("order", 0)) - for entry in users: - user_id = entry["user_id"] + for user in users: + user_id = user["user_id"] if not self.is_mine_id(requester_user_id): attestation = await self.store.get_remote_attestation(group_id, user_id) if not attestation: continue - entry["attestation"] = attestation + user["attestation"] = attestation else: - entry["attestation"] = self.attestations.create_attestation( + user["attestation"] = self.attestations.create_attestation( group_id, user_id ) user_profile = await self.profile_handler.get_profile_from_cache(user_id) - entry.update(user_profile) + user.update(user_profile) users.sort(key=lambda e: e.get("order", 0)) @@ -158,46 +178,44 @@ class GroupsServerWorkerHandler: "user": membership_info, } - async def get_group_categories(self, group_id, requester_user_id): - """Get all categories in a group (as seen by user) - """ + async def get_group_categories( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get all categories in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) categories = await self.store.get_group_categories(group_id=group_id) return {"categories": categories} - async def get_group_category(self, group_id, requester_user_id, category_id): - """Get a specific category in a group (as seen by user) - """ + async def get_group_category( + self, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: + """Get a specific category in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - res = await self.store.get_group_category( + return await self.store.get_group_category( group_id=group_id, category_id=category_id ) - logger.info("group %s", res) - - return res - - async def get_group_roles(self, group_id, requester_user_id): - """Get all roles in a group (as seen by user) - """ + async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict: + """Get all roles in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) roles = await self.store.get_group_roles(group_id=group_id) return {"roles": roles} - async def get_group_role(self, group_id, requester_user_id, role_id): - """Get a specific role in a group (as seen by user) - """ + async def get_group_role( + self, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: + """Get a specific role in a group (as seen by user)""" await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - res = await self.store.get_group_role(group_id=group_id, role_id=role_id) - return res + return await self.store.get_group_role(group_id=group_id, role_id=role_id) - async def get_group_profile(self, group_id, requester_user_id): - """Get the group profile as seen by requester_user_id - """ + async def get_group_profile( + self, group_id: str, requester_user_id: str + ) -> JsonDict: + """Get the group profile as seen by requester_user_id""" await self.check_group_is_ours(group_id, requester_user_id) @@ -218,7 +236,9 @@ class 
GroupsServerWorkerHandler: else: raise SynapseError(404, "Unknown group") - async def get_users_in_group(self, group_id, requester_user_id): + async def get_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the users in group as seen by requester_user_id. The ordering is arbitrary at the moment @@ -267,7 +287,9 @@ class GroupsServerWorkerHandler: return {"chunk": chunk, "total_user_count_estimate": len(user_results)} - async def get_invited_users_in_group(self, group_id, requester_user_id): + async def get_invited_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the users that have been invited to a group as seen by requester_user_id. The ordering is arbitrary at the moment @@ -297,7 +319,9 @@ class GroupsServerWorkerHandler: return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} - async def get_rooms_in_group(self, group_id, requester_user_id): + async def get_rooms_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the rooms in group as seen by requester_user_id This returns rooms in order of decreasing number of joined users @@ -335,17 +359,21 @@ class GroupsServerWorkerHandler: class GroupsServerHandler(GroupsServerWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # Ensure attestations get renewed hs.get_groups_attestation_renewer() async def update_group_summary_room( - self, group_id, requester_user_id, room_id, category_id, content - ): - """Add/update a room to the group summary - """ + self, + group_id: str, + requester_user_id: str, + room_id: str, + category_id: str, + content: JsonDict, + ) -> JsonDict: + """Add/update a room to the group summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -367,10 +395,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} async def delete_group_summary_room( - self, group_id, requester_user_id, room_id, category_id - ): - """Remove a room from the summary - """ + self, group_id: str, requester_user_id: str, room_id: str, category_id: str + ) -> JsonDict: + """Remove a room from the summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -381,7 +408,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def set_group_join_policy(self, group_id, requester_user_id, content): + async def set_group_join_policy( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Sets the group join policy. 
Currently supported policies are: @@ -401,10 +430,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} async def update_group_category( - self, group_id, requester_user_id, category_id, content - ): - """Add/Update a group category - """ + self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict + ) -> JsonDict: + """Add/Update a group category""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -421,9 +449,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def delete_group_category(self, group_id, requester_user_id, category_id): - """Delete a group category - """ + async def delete_group_category( + self, group_id: str, requester_user_id: str, category_id: str + ) -> JsonDict: + """Delete a group category""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -434,9 +463,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def update_group_role(self, group_id, requester_user_id, role_id, content): - """Add/update a role in a group - """ + async def update_group_role( + self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict + ) -> JsonDict: + """Add/update a role in a group""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -451,9 +481,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def delete_group_role(self, group_id, requester_user_id, role_id): - """Remove role from group - """ + async def delete_group_role( + self, group_id: str, requester_user_id: str, role_id: str + ) -> JsonDict: + """Remove role from group""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -463,10 +494,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} async def update_group_summary_user( - self, group_id, requester_user_id, user_id, role_id, content - ): - """Add/update a users entry in the group summary - """ + self, + group_id: str, + requester_user_id: str, + user_id: str, + role_id: str, + content: JsonDict, + ) -> JsonDict: + """Add/update a users entry in the group summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -486,10 +521,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} async def delete_group_summary_user( - self, group_id, requester_user_id, user_id, role_id - ): - """Remove a user from the group summary - """ + self, group_id: str, requester_user_id: str, user_id: str, role_id: str + ) -> JsonDict: + """Remove a user from the group summary""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -500,26 +534,43 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def update_group_profile(self, group_id, requester_user_id, content): - """Update the group profile - """ + async def update_group_profile( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> None: + """Update the group profile""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) profile = {} - for keyname in ("name", "avatar_url", "short_description", "long_description"): + for keyname, max_length in ( + ("name", MAX_DISPLAYNAME_LEN), + ("avatar_url", MAX_AVATAR_URL_LEN), 
+ ("short_description", MAX_SHORT_DESC_LEN), + ("long_description", MAX_LONG_DESC_LEN), + ): if keyname in content: value = content[keyname] if not isinstance(value, str): - raise SynapseError(400, "%r value is not a string" % (keyname,)) + raise SynapseError( + 400, + "%r value is not a string" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) + if len(value) > max_length: + raise SynapseError( + 400, + "Invalid %s parameter" % (keyname,), + errcode=Codes.INVALID_PARAM, + ) profile[keyname] = value await self.store.update_group_profile(group_id, profile) - async def add_room_to_group(self, group_id, requester_user_id, room_id, content): - """Add room to group - """ + async def add_room_to_group( + self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict + ) -> JsonDict: + """Add room to group""" RoomID.from_string(room_id) # Ensure valid room id await self.check_group_is_ours( @@ -533,10 +584,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} async def update_room_in_group( - self, group_id, requester_user_id, room_id, config_key, content - ): - """Update room in group - """ + self, + group_id: str, + requester_user_id: str, + room_id: str, + config_key: str, + content: JsonDict, + ) -> JsonDict: + """Update room in group""" RoomID.from_string(room_id) # Ensure valid room id await self.check_group_is_ours( @@ -554,9 +609,10 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def remove_room_from_group(self, group_id, requester_user_id, room_id): - """Remove room from group - """ + async def remove_room_from_group( + self, group_id: str, requester_user_id: str, room_id: str + ) -> JsonDict: + """Remove room from group""" await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) @@ -565,13 +621,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def invite_to_group(self, group_id, user_id, requester_user_id, content): - """Invite user to group - """ + async def invite_to_group( + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: + """Invite user to group""" group = await self.check_group_is_ours( group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id ) + if not group: + raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE) # TODO: Check if user knocked @@ -594,6 +653,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() + assert isinstance( + groups_local, GroupsLocalHandler + ), "Workers cannot invites users to groups." res = await groups_local.on_invite(group_id, user_id, content) local_attestation = None else: @@ -629,6 +691,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): local_attestation=local_attestation, remote_attestation=remote_attestation, ) + return {"state": "join"} elif res["state"] == "invite": await self.store.add_group_invite(group_id, user_id) return {"state": "invite"} @@ -637,13 +700,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler): else: raise SynapseError(502, "Unknown state returned by HS") - async def _add_user(self, group_id, user_id, content): + async def _add_user( + self, group_id: str, user_id: str, content: JsonDict + ) -> Optional[JsonDict]: """Add a user to a group based on a content dict. See accept_invite, join_group. 
""" if not self.hs.is_mine_id(user_id): - local_attestation = self.attestations.create_attestation(group_id, user_id) + local_attestation = self.attestations.create_attestation( + group_id, user_id + ) # type: Optional[JsonDict] remote_attestation = content["attestation"] @@ -667,7 +734,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return local_attestation - async def accept_invite(self, group_id, requester_user_id, content): + async def accept_invite( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """User tries to accept an invite to the group. This is different from them asking to join, and so should error if no @@ -686,7 +755,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {"state": "join", "attestation": local_attestation} - async def join_group(self, group_id, requester_user_id, content): + async def join_group( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """User tries to join the group. This will error if the group requires an invite/knock to join @@ -695,6 +766,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler): group_info = await self.check_group_is_ours( group_id, requester_user_id, and_exists=True ) + if not group_info: + raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND) if group_info["join_policy"] != "open": raise SynapseError(403, "Group is not publicly joinable") @@ -702,26 +775,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {"state": "join", "attestation": local_attestation} - async def knock(self, group_id, requester_user_id, content): - """A user requests becoming a member of the group - """ - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - raise NotImplementedError() - - async def accept_knock(self, group_id, requester_user_id, content): - """Accept a users knock to the room. - - Errors if the user hasn't knocked, rather than inviting them. - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - raise NotImplementedError() - async def remove_user_from_group( - self, group_id, user_id, requester_user_id, content - ): + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Remove a user from the group; either a user is leaving or an admin kicked them. """ @@ -743,6 +799,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): if is_kick: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() + assert isinstance( + groups_local, GroupsLocalHandler + ), "Workers cannot remove users from groups." await groups_local.user_removed_from_group(group_id, user_id, {}) else: await self.transport_client.remove_user_from_group_notification( @@ -759,14 +818,15 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {} - async def create_group(self, group_id, requester_user_id, content): - group = await self.check_group_is_ours(group_id, requester_user_id) - + async def create_group( + self, group_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: logger.info("Attempting to create group with ID: %r", group_id) # parsing the id into a GroupID validates it. 
group_id_obj = GroupID.from_string(group_id) + group = await self.check_group_is_ours(group_id, requester_user_id) if group: raise SynapseError(400, "Group already exists") @@ -811,7 +871,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): local_attestation = self.attestations.create_attestation( group_id, requester_user_id - ) + ) # type: Optional[JsonDict] else: local_attestation = None remote_attestation = None @@ -834,15 +894,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler): return {"group_id": group_id} - async def delete_group(self, group_id, requester_user_id): + async def delete_group(self, group_id: str, requester_user_id: str) -> None: """Deletes a group, kicking out all current members. Only group admins or server admins can call this request Args: - group_id (str) - request_user_id (str) - + group_id: The group ID to delete. + requester_user_id: The user requesting to delete the group. """ await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) @@ -865,6 +924,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler): async def _kick_user_from_group(user_id): if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() + assert isinstance( + groups_local, GroupsLocalHandler + ), "Workers cannot kick users from groups." await groups_local.user_removed_from_group(group_id, user_id, {}) else: await self.transport_client.remove_user_from_group_notification( @@ -896,9 +958,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler): await self.store.delete_group(group_id) -def _parse_join_policy_from_contents(content): - """Given a content for a request, return the specified join policy or None - """ +def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]: + """Given a content for a request, return the specified join policy or None""" join_policy_dict = content.get("m.join_policy") if join_policy_dict: @@ -907,9 +968,8 @@ def _parse_join_policy_from_contents(content): return None -def _parse_join_policy_dict(join_policy_dict): - """Given a dict for the "m.join_policy" config return the join policy specified - """ +def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str: + """Given a dict for the "m.join_policy" config return the join policy specified""" join_policy_type = join_policy_dict.get("type") if not join_policy_type: return "invite" @@ -919,7 +979,7 @@ def _parse_join_policy_dict(join_policy_dict): return join_policy_type -def _parse_visibility_from_contents(content): +def _parse_visibility_from_contents(content: JsonDict) -> bool: """Given a content for a request parse out whether the entity should be public or not """ @@ -933,7 +993,7 @@ def _parse_visibility_from_contents(content): return is_public -def _parse_visibility_dict(visibility): +def _parse_visibility_dict(visibility: JsonDict) -> bool: """Given a dict for the "m.visibility" config return if the entity should be public or not """ diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
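The _parse_join_policy_from_contents and _parse_join_policy_dict helpers above fall back to "invite" when a request supplies an m.join_policy dict with no type. A condensed standalone sketch of that behaviour; validation of the allowed policy values is omitted here.

from typing import Optional


def parse_join_policy(content: dict) -> Optional[str]:
    # Returns the policy named under m.join_policy, or None when the
    # request does not set one; a dict without a "type" means "invite".
    join_policy_dict = content.get("m.join_policy")
    if join_policy_dict is None:
        return None
    return join_policy_dict.get("type") or "invite"


assert parse_join_policy({}) is None
assert parse_join_policy({"m.join_policy": {}}) == "invite"
assert parse_join_policy({"m.join_policy": {"type": "open"}}) == "open"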
index 664d09da1c..ce97fa70d7 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py
@@ -18,11 +18,14 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional, Tuple from synapse.api.errors import StoreError, SynapseError from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import UserID from synapse.util import stringutils @@ -40,27 +43,37 @@ class AccountValidityHandler: self.sendmail = self.hs.get_sendmail() self.clock = self.hs.get_clock() - self._account_validity = self.hs.config.account_validity + self._account_validity_enabled = self.hs.config.account_validity_enabled + self._account_validity_renew_by_email_enabled = ( + self.hs.config.account_validity_renew_by_email_enabled + ) + self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory + self.profile_handler = self.hs.get_profile_handler() + + self._account_validity_period = None + if self._account_validity_enabled: + self._account_validity_period = self.hs.config.account_validity_period if ( - self._account_validity.enabled - and self._account_validity.renew_by_email_enabled + self._account_validity_enabled + and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. self._template_html = self.config.account_validity_template_html self._template_text = self.config.account_validity_template_text + account_validity_renew_email_subject = ( + self.hs.config.account_validity_renew_email_subject + ) try: app_name = self.hs.config.email_app_name - self._subject = self._account_validity.renew_email_subject % { - "app": app_name - } + self._subject = account_validity_renew_email_subject % {"app": app_name} self._from_string = self.hs.config.email_notif_from % {"app": app_name} except Exception: # If substitution failed, fall back to the bare strings. - self._subject = self._account_validity.renew_email_subject + self._subject = account_validity_renew_email_subject self._from_string = self.hs.config.email_notif_from self._raw_from = email.utils.parseaddr(self._from_string)[1] @@ -69,6 +82,18 @@ class AccountValidityHandler: if hs.config.run_background_tasks: self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) + # Mark users as inactive when they expired. Check once every hour + if self._account_validity_enabled: + + def mark_expired_users_as_inactive(): + # run as a background process to allow async functions to work + return run_as_background_process( + "_mark_expired_users_as_inactive", + self._mark_expired_users_as_inactive, + ) + + self.clock.looping_call(mark_expired_users_as_inactive, 60 * 60 * 1000) + @wrap_as_background_process("send_renewals") async def _send_renewal_emails(self) -> None: """Gets the list of users whose account is expiring in the amount of time @@ -221,47 +246,107 @@ class AccountValidityHandler: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - async def renew_account(self, renewal_token: str) -> bool: + async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. 
+ If it turns out that the token is valid but has already been used, then the + token is considered stale. A token is stale if the 'token_used_ts_ms' db column + is non-null. + Args: renewal_token: Token sent with the renewal request. Returns: - Whether the provided token is valid. + A tuple containing: + * A bool representing whether the token is valid and unused. + * A bool representing whether the token is stale. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. """ try: - user_id = await self.store.get_user_from_renewal_token(renewal_token) + ( + user_id, + current_expiration_ts, + token_used_ts, + ) = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - return False + return False, False, 0 + + # Check whether this token has already been used. + if token_used_ts: + logger.info( + "User '%s' attempted to use previously used token '%s' to renew account", + user_id, + renewal_token, + ) + return False, True, current_expiration_ts logger.debug("Renewing an account for user %s", user_id) - await self.renew_account_for_user(user_id) - return True + # Renew the account. Pass the renewal_token here so that it is not cleared. + # We want to keep the token around in case the user attempts to renew their + # account with the same token twice (clicking the email link twice). + # + # In that case, the token will be accepted, but the account's expiration ts + # will remain unchanged. + new_expiration_ts = await self.renew_account_for_user( + user_id, renewal_token=renewal_token + ) + + return True, False, new_expiration_ts async def renew_account_for_user( - self, user_id: str, expiration_ts: int = None, email_sent: bool = False + self, + user_id: str, + expiration_ts: Optional[int] = None, + email_sent: bool = False, + renewal_token: Optional[str] = None, ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token: Token sent with the renewal request. + user_id: The ID of the user to renew. expiration_ts: New expiration date. Defaults to now + validity period. - email_sen: Whether an email has been sent for this validity period. - Defaults to False. + email_sent: Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. The user's token + will be cleared if this is None. Returns: New expiration date for this account, as a timestamp in milliseconds since epoch. """ + now = self.clock.time_msec() if expiration_ts is None: - expiration_ts = self.clock.time_msec() + self._account_validity.period + expiration_ts = now + self._account_validity_period await self.store.set_account_validity_for_user( - user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent + user_id=user_id, + expiration_ts=expiration_ts, + email_sent=email_sent, + renewal_token=renewal_token, + token_used_ts=now, ) + # Check if renewed users should be reintroduced to the user directory + if self._show_users_in_user_directory: + # Show the user in the directory again by setting them to active + await self.profile_handler.set_active( + [UserID.from_string(user_id)], True, True + ) + return expiration_ts + + async def _mark_expired_users_as_inactive(self): + """Iterate over active, expired users. Mark them as inactive in order to hide them + from the user directory. 
+ """ + # Get active, expired users + active_expired_users = await self.store.get_expired_users() + + # Mark each as non-active + await self.profile_handler.set_active(active_expired_users, False, True) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
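renew_account now reports token validity, staleness and the resulting expiry as a 3-tuple. A sketch of how a caller, such as a renewal servlet, might branch on it; the handler object and the returned messages are illustrative only.

# Sketch only: "handler" stands in for AccountValidityHandler; a real
# renewal endpoint would render a different template for each case.
async def describe_renewal(handler, renewal_token: str) -> str:
    token_valid, token_stale, expiration_ts = await handler.renew_account(
        renewal_token
    )
    if token_valid:
        # Fresh token: the expiry was pushed back to expiration_ts.
        return "account renewed until %d" % expiration_ts
    if token_stale:
        # Token already used (token_used_ts_ms set): expiry unchanged.
        return "token already used; account still expires at %d" % expiration_ts
    # Unknown token: expiration_ts is 0 in this case.
    return "invalid renewal token"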
index 37e63da9b1..db68c94c50 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py
@@ -203,13 +203,11 @@ class AdminHandler(BaseHandler): class ExfiltrationWriter(metaclass=abc.ABCMeta): - """Interface used to specify how to write exported data. - """ + """Interface used to specify how to write exported data.""" @abc.abstractmethod def write_events(self, room_id: str, events: List[EventBase]) -> None: - """Write a batch of events for a room. - """ + """Write a batch of events for a room.""" raise NotImplementedError() @abc.abstractmethod diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
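ExfiltrationWriter above is an abstract interface for writing exported data. A minimal illustrative subclass that appends each room's events as JSON lines; the interface is reduced to the one method shown in the diff, events are plain dicts rather than EventBase objects, and the file layout is this sketch's own choice.

import abc
import json
import os
from typing import List


class ExfiltrationWriter(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def write_events(self, room_id: str, events: List[dict]) -> None:
        """Write a batch of events for a room."""
        raise NotImplementedError()


class JsonLinesWriter(ExfiltrationWriter):
    """Appends each room's events to a .jsonl file, one event per line."""

    def __init__(self, base_dir: str) -> None:
        self.base_dir = base_dir

    def write_events(self, room_id: str, events: List[dict]) -> None:
        # Room IDs contain characters that are awkward in file names.
        safe_name = room_id.replace("/", "_").replace(":", "_")
        path = os.path.join(self.base_dir, safe_name + ".jsonl")
        with open(path, "a") as f:
            for event in events:
                f.write(json.dumps(event) + "\n")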
index 5c6458eb52..deab8ff2d0 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -290,7 +290,9 @@ class ApplicationServicesHandler: if not interested: continue presence_events, _ = await presence_source.get_new_events( - user=user, service=service, from_key=from_key, + user=user, + service=service, + from_key=from_key, ) time_now = self.clock.time_msec() events.extend( diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a19c556437..9ba9f591d9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier( # Ensure the identifier has a type if "type" not in identifier: raise SynapseError( - 400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM, + 400, + "'identifier' dict has no key 'type'", + errcode=Codes.MISSING_PARAM, ) return identifier @@ -351,7 +353,11 @@ class AuthHandler(BaseHandler): try: result, params, session_id = await self.check_ui_auth( - flows, request, request_body, description, get_new_session_data, + flows, + request, + request_body, + description, + get_new_session_data, ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). @@ -379,8 +385,7 @@ class AuthHandler(BaseHandler): return params, session_id async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: - """Get a list of the authentication types this user can use - """ + """Get a list of the authentication types this user can use""" ui_auth_types = set() @@ -723,7 +728,9 @@ class AuthHandler(BaseHandler): } def _auth_dict_for_flows( - self, flows: List[List[str]], session_id: str, + self, + flows: List[List[str]], + session_id: str, ) -> Dict[str, Any]: public_flows = [] for f in flows: @@ -880,7 +887,9 @@ class AuthHandler(BaseHandler): return self._supported_login_types async def validate_login( - self, login_submission: Dict[str, Any], ratelimit: bool = False, + self, + login_submission: Dict[str, Any], + ratelimit: bool = False, ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Authenticates the user for the /login API @@ -1023,7 +1032,9 @@ class AuthHandler(BaseHandler): raise async def _validate_userid_login( - self, username: str, login_submission: Dict[str, Any], + self, + username: str, + login_submission: Dict[str, Any], ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]: """Helper for validate_login @@ -1446,7 +1457,8 @@ class AuthHandler(BaseHandler): # is considered OK since the newest SSO attributes should be most valid. if extra_attributes: self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( - self._clock.time_msec(), extra_attributes, + self._clock.time_msec(), + extra_attributes, ) # Create a login token @@ -1472,10 +1484,22 @@ class AuthHandler(BaseHandler): # Remove the query parameters from the redirect URL to get a shorter version of # it. This is only to display a human-readable URL in the template, but not the # URL we redirect users to. - redirect_url_no_params = client_redirect_url.split("?")[0] + url_parts = urllib.parse.urlsplit(client_redirect_url) + + if url_parts.scheme == "https": + # for an https uri, just show the netloc (ie, the hostname. Specifically, + # the bit between "//" and "/"; this includes any potential + # "username:password@" prefix.) + display_url = url_parts.netloc + else: + # for other uris, strip the query-params (including the login token) and + # fragment. + display_url = urllib.parse.urlunsplit( + (url_parts.scheme, url_parts.netloc, url_parts.path, "", "") + ) html = self._sso_redirect_confirm_template.render( - display_url=redirect_url_no_params, + display_url=display_url, redirect_url=redirect_url, server_name=self._server_name, new_user=new_user, @@ -1690,5 +1714,9 @@ class PasswordProvider: # This might return an awaitable, if it does block the log out # until it completes. 
await maybe_awaitable( - g(user_id=user_id, device_id=device_id, access_token=access_token,) + g( + user_id=user_id, + device_id=device_id, + access_token=access_token, + ) ) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
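The new display_url logic in auth.py above is self-contained enough to demonstrate directly: https URLs are reduced to their netloc, while other schemes keep their path but lose the query string (which may carry a login token) and fragment. A standalone version with the two cases from the diff:

import urllib.parse


def display_url_for(client_redirect_url: str) -> str:
    # For https URLs show just the netloc (hostname, plus any potential
    # "username:password@" prefix); for other schemes strip the
    # query parameters and fragment.
    url_parts = urllib.parse.urlsplit(client_redirect_url)
    if url_parts.scheme == "https":
        return url_parts.netloc
    return urllib.parse.urlunsplit(
        (url_parts.scheme, url_parts.netloc, url_parts.path, "", "")
    )


assert display_url_for("https://example.com/cb?token=abc") == "example.com"
assert display_url_for("app://callback/path?token=abc") == "app://callback/path"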
index bd35d1fb87..04972f9cf0 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from xml.etree import ElementTree as ET import attr @@ -33,8 +33,7 @@ logger = logging.getLogger(__name__) class CasError(Exception): - """Used to catch errors when validating the CAS ticket. - """ + """Used to catch errors when validating the CAS ticket.""" def __init__(self, error, error_description=None): self.error = error @@ -49,7 +48,7 @@ class CasError(Exception): @attr.s(slots=True, frozen=True) class CasResponse: username = attr.ib(type=str) - attributes = attr.ib(type=Dict[str, Optional[str]]) + attributes = attr.ib(type=Dict[str, List[Optional[str]]]) class CasHandler: @@ -100,7 +99,10 @@ class CasHandler: Returns: The URL to use as a "service" parameter. """ - return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) + return "%s?%s" % ( + self._cas_service_url, + urllib.parse.urlencode(args), + ) async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] @@ -169,7 +171,7 @@ class CasHandler: # Iterate through the nodes and pull out the user and any extra attributes. user = None - attributes = {} + attributes = {} # type: Dict[str, List[Optional[str]]] for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -182,7 +184,7 @@ class CasHandler: tag = attribute.tag if "}" in tag: tag = tag.split("}")[1] - attributes[tag] = attribute.text + attributes.setdefault(tag, []).append(attribute.text) # Ensure a user was found. if user is None: @@ -296,36 +298,20 @@ class CasHandler: # first check if we're doing a UIA if session: return await self._sso_handler.complete_sso_ui_auth_request( - self.idp_id, cas_response.username, session, request, + self.idp_id, + cas_response.username, + session, + request, ) # otherwise, we're handling a login request. # Ensure that the attributes of the logged in user meet the required # attributes. - for required_attribute, required_value in self._cas_required_attributes.items(): - # If required attribute was not in CAS Response - Forbidden - if required_attribute not in cas_response.attributes: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return - - # Also need to check value - if required_value is not None: - actual_value = cas_response.attributes[required_attribute] - # If required attribute value does not match expected - Forbidden - if required_value != actual_value: - self._sso_handler.render_error( - request, - "unauthorised", - "You are not authorised to log in here.", - 401, - ) - return + if not self._sso_handler.check_required_attributes( + request, cas_response.attributes, self._cas_required_attributes + ): + return # Call the mapper to register/login the user @@ -372,9 +358,10 @@ class CasHandler: if failures: raise RuntimeError("CAS is not expected to de-duplicate Matrix IDs") + # Arbitrarily use the first attribute found. 
display_name = cas_response.attributes.get( - self._cas_displayname_attribute, None - ) + self._cas_displayname_attribute, [None] + )[0] return UserAttributes(localpart=localpart, display_name=display_name) @@ -384,7 +371,8 @@ class CasHandler: user_id = UserID(localpart, self._hostname).to_string() logger.debug( - "Looking for existing account based on mapped %s", user_id, + "Looking for existing account based on mapped %s", + user_id, ) users = await self._store.get_users_by_id_case_insensitive(user_id) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index c4a3b26a84..7911d126f5 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py
@@ -50,7 +50,7 @@ class DeactivateAccountHandler(BaseHandler): if hs.config.run_background_tasks: hs.get_reactor().callWhenRunning(self._start_user_parting) - self._account_validity_enabled = hs.config.account_validity.enabled + self._account_validity_enabled = hs.config.account_validity_enabled async def deactivate_account( self, @@ -120,6 +120,9 @@ class DeactivateAccountHandler(BaseHandler): await self.store.user_set_password_hash(user_id, None) + user = UserID.from_string(user_id) + await self._profile_handler.set_active([user], False, False) + # Add the user to a table of users pending deactivation (ie. # removal from all the rooms they're a member of) await self.store.add_user_pending_deactivation(user_id) @@ -196,8 +199,7 @@ class DeactivateAccountHandler(BaseHandler): run_as_background_process("user_parter_loop", self._user_parter_loop) async def _user_parter_loop(self) -> None: - """Loop that parts deactivated users from rooms - """ + """Loop that parts deactivated users from rooms""" self._user_parter_running = True logger.info("Starting user parter") try: @@ -214,8 +216,7 @@ class DeactivateAccountHandler(BaseHandler): self._user_parter_running = False async def _part_user(self, user_id: str) -> None: - """Causes the given user_id to leave all the rooms they're joined to - """ + """Causes the given user_id to leave all the rooms they're joined to""" user = UserID.from_string(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 0863154f7a..df3cdc8fba 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py
@@ -86,7 +86,7 @@ class DeviceWorkerHandler(BaseHandler): @trace async def get_device(self, user_id: str, device_id: str) -> JsonDict: - """ Retrieve the given device + """Retrieve the given device Args: user_id: The user to get the device from @@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler): @trace async def delete_device(self, user_id: str, device_id: str) -> None: - """ Delete the given device + """Delete the given device Args: user_id: The user to delete the device from. @@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler): await self.delete_devices(user_id, device_ids) async def delete_devices(self, user_id: str, device_ids: List[str]) -> None: - """ Delete several devices + """Delete several devices Args: user_id: The user to delete devices from. @@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler): await self.notify_device_update(user_id, device_ids) async def update_device(self, user_id: str, device_id: str, content: dict) -> None: - """ Update the given device + """Update the given device Args: user_id: The user to update devices of. @@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler): device id of the dehydrated device """ device_id = await self.check_device_registered( - user_id, None, initial_device_display_name, + user_id, + None, + initial_device_display_name, ) old_device_id = await self.store.store_dehydrated_device( user_id, device_id, device_data @@ -803,7 +805,8 @@ class DeviceListUpdater: try: # Try to resync the current user's devices list. result = await self.user_device_resync( - user_id=user_id, mark_failed_as_stale=False, + user_id=user_id, + mark_failed_as_stale=False, ) # user_device_resync only returns a result if it managed to # successfully resync and update the database. Updating the table # of users requiring resync isn't necessary here as # user_device_resync already does it (through # self.store.update_remote_device_list_cache). if result: logger.debug( - "Successfully resynced the device list for %s", user_id, + "Successfully resynced the device list for %s", + user_id, ) except Exception as e: # If there was an issue resyncing this user, e.g. if the remote # server sent a malformed result, just log the error instead of # aborting all the subsequent resyncs. logger.debug( - "Could not resync the device list for %s: %s", user_id, e, + "Could not resync the device list for %s: %s", + user_id, + e, ) finally: # Allow future calls to retry resyncing out of sync device lists. @@ -855,7 +861,9 @@ class DeviceListUpdater: return None except (RequestSendFailed, HttpResponseException) as e: logger.warning( - "Failed to handle device list update for %s: %s", user_id, e, + "Failed to handle device list update for %s: %s", + user_id, + e, ) if mark_failed_as_stale: @@ -931,7 +939,9 @@ class DeviceListUpdater: # Handle cross-signing keys. cross_signing_device_ids = await self.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + user_id, + master_key, + self_signing_key, ) device_ids = device_ids + cross_signing_device_ids diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 0c7737e09d..1aa7d803b5 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -62,7 +62,8 @@ class DeviceMessageHandler: ) else: hs.get_federation_registry().register_instances_for_edu( - "m.direct_to_device", hs.config.worker.writers.to_device, + "m.direct_to_device", + hs.config.worker.writers.to_device, ) # The handler to call when we think a user's device list might be out of @@ -73,8 +74,8 @@ class DeviceMessageHandler: hs.get_device_handler().device_list_updater.user_device_resync ) else: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8f3a6b35a4..9a946a3cfe 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -61,8 +61,8 @@ class E2eKeysHandler: self._is_master = hs.config.worker_app is None if not self._is_master: - self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync_client = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) else: # Only register this edu handler on master as it requires writing @@ -85,7 +85,7 @@ class E2eKeysHandler: async def query_devices( self, query_body: JsonDict, timeout: int, from_user_id: str ) -> JsonDict: - """ Handle a device key query from a client + """Handle a device key query from a client { "device_keys": { @@ -391,8 +391,7 @@ class E2eKeysHandler: async def on_federation_query_client_keys( self, query_body: Dict[str, Dict[str, Optional[List[str]]]] ) -> JsonDict: - """ Handle a device key query from a federated server - """ + """Handle a device key query from a federated server""" device_keys_query = query_body.get( "device_keys", {} ) # type: Dict[str, Optional[List[str]]] @@ -1065,7 +1064,9 @@ class E2eKeysHandler: return key, key_id, verify_key async def _retrieve_cross_signing_keys_for_remote_user( - self, user: UserID, desired_key_type: str, + self, + user: UserID, + desired_key_type: str, ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database @@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: @attr.s(slots=True) class SignatureListItem: - """An item in the signature list as used by upload_signatures_for_device_keys. - """ + """An item in the signature list as used by upload_signatures_for_device_keys.""" signing_key_id = attr.ib(type=str) target_user_id = attr.ib(type=str) @@ -1355,8 +1355,12 @@ class SigningKeyEduUpdater: logger.info("pending updates: %r", pending_updates) for master_key, self_signing_key in pending_updates: - new_device_ids = await device_list_updater.process_cross_signing_key_update( - user_id, master_key, self_signing_key, + new_device_ids = ( + await device_list_updater.process_cross_signing_key_update( + user_id, + master_key, + self_signing_key, + ) ) device_ids = device_ids + new_device_ids diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 539b4fc32e..3e23f82cf7 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py
@@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler): room_id: Optional[str] = None, is_guest: bool = False, ) -> JsonDict: - """Fetches the events stream for a given user. - """ + """Fetches the events stream for a given user.""" if room_id: blocked = await self.store.is_room_blocked(room_id) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eddc7582d0..51bdf97920 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py
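A large part of the federation.py diff below adds the knocking flow. Stripped of error handling, do_knock drives a make_knock/send_knock handshake; in rough outline (names exactly as in the diff, not the full implementation):

    # 1. ask a resident server to template the knock event for us
    origin, event, event_format_version = await self._make_and_verify_event(
        target_hosts, room_id, knockee, Membership.KNOCK, content, params=params
    )
    # 2. send the signed knock back; the response carries stripped room state
    stripped_room_state = await self.federation_client.send_knock(target_hosts, event)
    # 3. stash that state on the event so it can surface in the knocking
    #    user's sync stream, mirroring how invites behave
    event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]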
@@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -111,13 +112,13 @@ class _NewEventInfo: class FederationHandler(BaseHandler): """Handles events that originated from federation. - Responsible for: - a) handling received Pdus before handing them on as Events to the rest - of the homeserver (including auth and state conflict resolutions) - b) converting events that were produced by local clients that may need - to be sent to remote homeservers. - c) doing the necessary dances to invite remote users and join remote - rooms. + Responsible for: + a) handling received Pdus before handing them on as Events to the rest + of the homeserver (including auth and state conflict resolutions) + b) converting events that were produced by local clients that may need + to be sent to remote homeservers. + c) doing the necessary dances to invite remote users and join remote + rooms. """ def __init__(self, hs: "HomeServer"): @@ -150,11 +151,11 @@ class FederationHandler(BaseHandler): ) if hs.config.worker_app: - self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client( - hs + self._user_device_resync = ( + ReplicationUserDevicesResyncRestServlet.make_client(hs) ) - self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client( - hs + self._maybe_store_room_on_outlier_membership = ( + ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs) ) else: self._device_list_updater = hs.get_device_handler().device_list_updater @@ -172,7 +173,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: - """ Process a PDU received via a federation /send/ transaction, or + """Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events Args: @@ -186,7 +187,7 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info("handling received PDU: %s", pdu) + logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) # We reprocess pdus when we have seen them only as outliers existing = await self.store.get_event( @@ -301,6 +302,14 @@ class FederationHandler(BaseHandler): room_id, event_id, ) + elif missing_prevs: + logger.info( + "[%s %s] Not recursively fetching %d missing prev_events: %s", + room_id, + event_id, + len(missing_prevs), + shortstr(missing_prevs), + ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -345,12 +354,6 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) - logger.info( - "Event %s is missing prev_events: calculating state for a " - "backwards extremity", - event_id, - ) - # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. 
event_map = {event_id: pdu} @@ -368,7 +371,8 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "Requesting state at missing prev_event %s", event_id, + "Requesting state at missing prev_event %s", + event_id, ) with nested_logging_context(p): @@ -388,12 +392,14 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x room_version = await self.store.get_room_version_id(room_id) - state_map = await self._state_resolution_handler.resolve_events_with_store( - room_id, - room_version, - state_maps, - event_map, - state_res_store=StateResolutionStore(self.store), + state_map = ( + await self._state_resolution_handler.resolve_events_with_store( + room_id, + room_version, + state_maps, + event_map, + state_res_store=StateResolutionStore(self.store), + ) ) # We need to give _process_received_pdu the actual state events @@ -402,9 +408,7 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. evs = await self.store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.AS_IS, + list(state_map.values()), get_prev_content=False, ) event_map.update(evs) @@ -687,9 +691,12 @@ class FederationHandler(BaseHandler): return fetched_events async def _process_received_pdu( - self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + self, + origin: str, + event: EventBase, + state: Optional[Iterable[EventBase]], ): - """ Called when we have a new pdu. We need to do auth checks and put it + """Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. Args: @@ -801,7 +808,7 @@ class FederationHandler(BaseHandler): @log_function async def backfill(self, dest, room_id, limit, extremities): - """ Trigger a backfill request to `dest` for the given `room_id` + """Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side has no new events to offer, this will return an empty list. @@ -1204,11 +1211,16 @@ class FederationHandler(BaseHandler): with nested_logging_context(event_id): try: event = await self.federation_client.get_pdu( - [destination], event_id, room_version, outlier=True, + [destination], + event_id, + room_version, + outlier=True, ) if event is None: logger.warning( - "Server %s didn't return event %s", destination, event_id, + "Server %s didn't return event %s", + destination, + event_id, ) return @@ -1235,7 +1247,8 @@ class FederationHandler(BaseHandler): if aid not in event_map ] persisted_events = await self.store.get_events( - auth_events, allow_rejected=True, + auth_events, + allow_rejected=True, ) event_infos = [] @@ -1251,7 +1264,9 @@ class FederationHandler(BaseHandler): event_infos.append(_NewEventInfo(event, None, auth)) await self._handle_new_events( - destination, room_id, event_infos, + destination, + room_id, + event_infos, ) def _sanity_check_event(self, ev): @@ -1287,7 +1302,7 @@ class FederationHandler(BaseHandler): raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def send_invite(self, target_host, event): - """ Sends the invite to the remote server for signing. + """Sends the invite to the remote server for signing. Invites must be signed by the invitee's server before distribution. 
""" @@ -1310,7 +1325,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict ) -> Tuple[str, int]: - """ Attempts to join the `joinee` to the room `room_id` via the + """Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. This first triggers a /make_join/ request that returns a partial @@ -1354,8 +1369,6 @@ class FederationHandler(BaseHandler): await self._clean_room_for_join(room_id) - handled_events = set() - try: # Try the host we successfully got a response to /make_join/ # request first. @@ -1375,10 +1388,6 @@ class FederationHandler(BaseHandler): auth_chain = ret["auth_chain"] auth_chain.sort(key=lambda e: e.depth) - handled_events.update([s.event_id for s in state]) - handled_events.update([a.event_id for a in auth_chain]) - handled_events.add(event.event_id) - logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join state: %s", state) @@ -1394,7 +1403,8 @@ class FederationHandler(BaseHandler): # so we can rely on it now. # await self.store.upsert_room_on_join( - room_id=room_id, room_version=room_version_obj, + room_id=room_id, + room_version=room_version_obj, ) max_stream_id = await self._persist_auth_tree( @@ -1439,6 +1449,73 @@ class FederationHandler(BaseHandler): run_in_background(self._handle_queued_pdus, room_queue) + @log_function + async def do_knock( + self, target_hosts: List[str], room_id: str, knockee: str, content: JsonDict, + ) -> Tuple[str, int]: + """Sends the knock to the remote server. + + This first triggers a make_knock request that returns a partial + event that we can fill out and sign. This is then sent to the + remote server via send_knock. + + Knock events must be signed by the knockee's server before distributing. + + Args: + target_hosts: A list of hosts that we want to try knocking through. + room_id: The ID of the room to knock on. + knockee: The ID of the user who is knocking. + content: The content of the knock event. + + Returns: + A tuple of (event ID, stream ID). + + Raises: + SynapseError: If the chosen remote server returns a 3xx/4xx code. + RuntimeError: If no servers were reachable. + """ + logger.debug("Knocking on room %s on behalf of user %s", room_id, knockee) + + # Inform the remote server of the room versions we support + supported_room_versions = list(KNOWN_ROOM_VERSIONS.keys()) + + # Ask the remote server to create a valid knock event for us. Once received, + # we sign the event + params = {"ver": supported_room_versions} # type: Dict[str, Iterable[str]] + origin, event, event_format_version = await self._make_and_verify_event( + target_hosts, room_id, knockee, Membership.KNOCK, content, params=params + ) + + # Record the room ID and its version so that we have a record of the room + await self._maybe_store_room_on_outlier_membership( + room_id=event.room_id, room_version=event_format_version + ) + + # Initially try the host that we successfully called /make_knock on + try: + target_hosts.remove(origin) + target_hosts.insert(0, origin) + except ValueError: + pass + + # Send the signed event back to the room, and potentially receive some + # further information about the room in the form of partial state events + stripped_room_state = await self.federation_client.send_knock( + target_hosts, event + ) + + # Store any stripped room state events in the "unsigned" key of the event. + # This is a bit of a hack and is cribbing off of invites. 
Basically we + # store the room state here and retrieve it again when this event appears + # in the invitee's sync stream. It is stripped out for all other local users. + event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] + + context = await self.state_handler.compute_event_context(event) + stream_id = await self.persist_events_and_notify( + event.room_id, [(event, context)] + ) + return event.event_id, stream_id + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. @@ -1464,7 +1541,7 @@ class FederationHandler(BaseHandler): async def on_make_join_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_join/ request, so we create a partial + """We've received a /make_join/ request, so we create a partial join event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1489,7 +1566,8 @@ class FederationHandler(BaseHandler): is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( - "Got /make_join request for room %s we are no longer in", room_id, + "Got /make_join request for room %s we are no longer in", + room_id, ) raise NotFoundError("Not an active room on this server") @@ -1523,7 +1601,7 @@ class FederationHandler(BaseHandler): return event async def on_send_join_request(self, origin, pdu): - """ We have received a join event for a room. Fully process it and + """We have received a join event for a room. Fully process it and respond with the current state and auth chains. """ event = pdu @@ -1579,7 +1657,7 @@ class FederationHandler(BaseHandler): async def on_invite_request( self, origin: str, event: EventBase, room_version: RoomVersion ): - """ We've got an invite event. Process and persist it. Sign it. + """We've got an invite event. Process and persist it. Sign it. Respond with the now signed event. """ @@ -1593,8 +1671,15 @@ class FederationHandler(BaseHandler): if self.hs.config.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") + is_published = await self.store.is_room_published(event.room_id) + if not await self.spam_checker.user_may_invite( - event.sender, event.state_key, event.room_id + event.sender, + event.state_key, + None, + room_id=event.room_id, + new_room=False, + published_room=is_published, ): raise SynapseError( 403, "This user is not permitted to send invites to this server/user" @@ -1706,7 +1791,7 @@ class FederationHandler(BaseHandler): async def on_make_leave_request( self, origin: str, room_id: str, user_id: str ) -> EventBase: - """ We've received a /make_leave/ request, so we create a partial + """We've received a /make_leave/ request, so we create a partial leave event for the room and return that. We do *not* persist or process it until the other server has signed it and sent it back. @@ -1781,9 +1866,122 @@ class FederationHandler(BaseHandler): return None - async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: - """Returns the state at the event. i.e. not including said event. + @log_function + async def on_make_knock_request( + self, origin: str, room_id: str, user_id: str + ) -> EventBase: + """We've received a make_knock request, so we create a partial + knock event for the room and return that. We do *not* persist or + process it until the other server has signed it and sent it back. 
+ + Args: + origin: The (verified) server name of the requesting server. + room_id: The room to create the knock event in. + user_id: The user to create the knock for. + + Returns: + The partial knock event. """ + if get_domain_from_id(user_id) != origin: + logger.info( + "Get /xyz.amorgan.knock/make_knock request for user %r " + "from different origin %s, ignoring", + user_id, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + room_version = await self.store.get_room_version_id(room_id) + + builder = self.event_builder_factory.new( + room_version, + { + "type": EventTypes.Member, + "content": {"membership": Membership.KNOCK}, + "room_id": room_id, + "sender": user_id, + "state_key": user_id, + }, + ) + + event, context = await self.event_creation_handler.create_new_client_event( + builder=builder + ) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.warning("Creation of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + try: + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_knock_request` + await self.auth.check_from_context( + room_version, event, context, do_sig_check=False + ) + except AuthError as e: + logger.warning("Failed to create new knock %r because %s", event, e) + raise e + + return event + + @log_function + async def on_send_knock_request( + self, origin: str, event: EventBase + ) -> EventContext: + """ + We have received a knock event for a room. Verify that event and send it into the room + on the knocking homeserver's behalf. + + Args: + origin: The remote homeserver of the knocking user. + event: The knocking member event that has been signed by the remote homeserver. + + Returns: + The context of the event after inserting it into the room graph. + """ + logger.debug( + "on_send_knock_request: Got event: %s, signatures: %s", + event.event_id, + event.signatures, + ) + + if get_domain_from_id(event.sender) != origin: + logger.info( + "Got /xyz.amorgan.knock/send_knock request for user %r " + "from different origin %s", + event.sender, + origin, + ) + raise SynapseError(403, "User not from origin", Codes.FORBIDDEN) + + event.internal_metadata.outlier = False + + context = await self._handle_new_event(origin, event) + + event_allowed = await self.third_party_event_rules.check_event_allowed( + event, context + ) + if not event_allowed: + logger.info("Sending of knock %s forbidden by third-party rules", event) + raise SynapseError( + 403, "This event is not allowed in this context", Codes.FORBIDDEN + ) + + logger.debug( + "on_send_knock_request: After _handle_new_event: %s, sigs: %s", + event.event_id, + event.signatures, + ) + + return context + + async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]: + """Returns the state at the event. i.e. not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) @@ -1809,8 +2007,7 @@ class FederationHandler(BaseHandler): return [] async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]: - """Returns the state at the event. i.e. not including said event. - """ + """Returns the state at the event. i.e.
not including said event.""" event = await self.store.get_event(event_id, check_room_id=room_id) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) @@ -2016,7 +2213,11 @@ class FederationHandler(BaseHandler): for e_id in missing_auth_events: m_ev = await self.federation_client.get_pdu( - [origin], e_id, room_version=room_version, outlier=True, timeout=10000, + [origin], + e_id, + room_version=room_version, + outlier=True, + timeout=10000, ) if m_ev and m_ev.event_id == e_id: event_map[e_id] = m_ev @@ -2166,7 +2367,9 @@ class FederationHandler(BaseHandler): ) logger.debug( - "Doing soft-fail check for %s: state %s", event.event_id, current_state_ids, + "Doing soft-fail check for %s: state %s", + event.event_id, + current_state_ids, ) # Now check if event pass auth against said current state @@ -2519,7 +2722,7 @@ class FederationHandler(BaseHandler): async def construct_auth_difference( self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase] ) -> Dict: - """ Given a local and remote auth chain, find the differences. This + """Given a local and remote auth chain, find the differences. This assumes that we have already processed all events in remote_auth Params: diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 71f11ef94a..bfb95e3eee 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py
@@ -146,8 +146,7 @@ class GroupsLocalWorkerHandler: async def get_users_in_group( self, group_id: str, requester_user_id: str ) -> JsonDict: - """Get users in a group - """ + """Get users in a group""" if self.is_mine_id(group_id): return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id @@ -283,8 +282,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def create_group( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Create a group - """ + """Create a group""" logger.info("Asking to create group with ID: %r", group_id) @@ -314,8 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def join_group( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Request to join a group - """ + """Request to join a group""" if self.is_mine_id(group_id): await self.groups_server_handler.join_group(group_id, user_id, content) local_attestation = None @@ -361,8 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def accept_invite( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """Accept an invite to a group - """ + """Accept an invite to a group""" if self.is_mine_id(group_id): await self.groups_server_handler.accept_invite(group_id, user_id, content) local_attestation = None @@ -408,8 +404,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def invite( self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict ) -> JsonDict: - """Invite a user to a group - """ + """Invite a user to a group""" content = {"requester_user_id": requester_user_id, "config": config} if self.is_mine_id(group_id): res = await self.groups_server_handler.invite_to_group( @@ -434,8 +429,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def on_invite( self, group_id: str, user_id: str, content: JsonDict ) -> JsonDict: - """One of our users were invited to a group - """ + """One of our users was invited to a group""" # TODO: Support auto join and rejection if not self.is_mine_id(user_id): @@ -466,8 +460,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def remove_user_from_group( self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict ) -> JsonDict: - """Remove a user from a group - """ + """Remove a user from a group""" if user_id == requester_user_id: token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) @@ -501,8 +494,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): async def user_removed_from_group( self, group_id: str, user_id: str, content: JsonDict ) -> None: - """One of our users was removed/kicked from a group - """ + """One of our users was removed/kicked from a group""" # TODO: Check if user in group token = await self.store.register_user_group_membership( group_id, user_id, membership="leave" ) diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 8fc1e8b91c..ac81fa3678 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py
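A recurring pattern in the identity.py changes below is that requests go to a rewritten id_server_url while the original id_server name is what gets stored. A minimal sketch of the rewrite behaviour, with a hypothetical mapping (the real map comes from the rewrite_identity_server_urls config option):

    rewrites = {"https://id.example.com": "http://sydent.internal:8090"}

    def rewrite_id_server_url(url: str, add_https: bool = False) -> str:
        if add_https:
            url = "https://" + url
        # return the rewritten URL if one is configured, else the input unchanged
        return rewrites.get(url, url)

    assert rewrite_id_server_url("id.example.com", add_https=True) == "http://sydent.internal:8090"
    assert rewrite_id_server_url("https://other.example.org") == "https://other.example.org"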
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,9 +22,11 @@ import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple from synapse.api.errors import ( + AuthError, CodeMessageException, Codes, HttpResponseException, + ProxiedRequestError, SynapseError, ) from synapse.api.ratelimiting import Ratelimiter @@ -41,8 +43,6 @@ from ._base import BaseHandler logger = logging.getLogger(__name__) -id_server_scheme = "https://" - class IdentityHandler(BaseHandler): def __init__(self, hs): @@ -57,6 +57,9 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_federation_http_client() self.hs = hs + self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls + self._enable_lookup = hs.config.enable_3pid_lookup + self._web_client_location = hs.config.invite_client_location # Ratelimiters for `/requestToken` endpoints. @@ -72,7 +75,10 @@ class IdentityHandler(BaseHandler): ) def ratelimit_request_token_requests( - self, request: SynapseRequest, medium: str, address: str, + self, + request: SynapseRequest, + medium: str, + address: str, ): """Used to ratelimit requests to `/requestToken` by IP and address. @@ -86,14 +92,14 @@ class IdentityHandler(BaseHandler): self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) async def threepid_from_creds( - self, id_server: str, creds: Dict[str, str] + self, id_server_url: str, creds: Dict[str, str] ) -> Optional[JsonDict]: """ Retrieve and validate a threepid identifier from a "credentials" dictionary against a given identity server Args: - id_server: The identity server to validate 3PIDs against. Must be a + id_server_url: The identity server to validate 3PIDs against. Must be a complete URL including the protocol (http(s)://) creds: Dictionary containing the following keys: * client_secret|clientSecret: A unique secret str provided by the client @@ -118,7 +124,14 @@ class IdentityHandler(BaseHandler): query_params = {"sid": session_id, "client_secret": client_secret} - url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) + + url = "%s%s" % ( + id_server_url, + "/_matrix/identity/api/v1/3pid/getValidated3pid", + ) try: data = await self.http_client.get_json(url, query_params) @@ -127,7 +140,7 @@ class IdentityHandler(BaseHandler): except HttpResponseException as e: logger.info( "%s returned %i for threepid validation for: %s", - id_server, + id_server_url, e.code, creds, ) @@ -141,7 +154,7 @@ class IdentityHandler(BaseHandler): if "medium" in data: return data - logger.info("%s reported non-validated threepid: %s", id_server, creds) + logger.info("%s reported non-validated threepid: %s", id_server_url, creds) return None async def bind_threepid( @@ -173,14 +186,19 @@ class IdentityHandler(BaseHandler): if id_access_token is None: use_v2 = False + # if we have a rewrite rule set for the identity server, + # apply it now, but only for sending the request (not + # storing in the database). 
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + # Decide which API endpoint URLs to use headers = {} bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid} if use_v2: - bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,) + bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,) headers["Authorization"] = create_id_access_token_header(id_access_token) # type: ignore else: - bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,) + bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,) try: # Use the blacklisting http client as this call is only to identity servers @@ -267,9 +285,6 @@ class IdentityHandler(BaseHandler): True on success, otherwise False if the identity server doesn't support unbinding """ - url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,) - url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") - content = { "mxid": mxid, "threepid": {"medium": threepid["medium"], "address": threepid["address"]}, @@ -278,6 +293,7 @@ # we abuse the federation http client to sign the request, but we have to send it # using the normal http client since we don't want the SRV lookup and want normal # 'browser-like' HTTPS. + url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii") auth_headers = self.federation_http_client.build_auth_headers( destination=None, method=b"POST", url_bytes=url_bytes, content=content, ) headers = {b"Authorization": auth_headers} + # if we have a rewrite rule set for the identity server, + # apply it now. + # + # Note that destination_is has to be the real id_server, not + # the server we connect to. + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,) + try: # Use the blacklisting http client as this call is only to identity servers # provided by a client @@ -401,9 +426,28 @@ class IdentityHandler(BaseHandler): return session_id + def rewrite_id_server_url(self, url: str, add_https=False) -> str: + """Given an identity server URL, optionally add a protocol scheme + before rewriting it according to the rewrite_identity_server_urls + config option + + Adds https:// to the URL if specified, then tries to rewrite the + url. Returns either the rewritten URL or the URL with optional + protocol scheme additions. + """ + rewritten_url = url + if add_https: + rewritten_url = "https://" + rewritten_url + + rewritten_url = self.rewrite_identity_server_urls.get( + rewritten_url, rewritten_url + ) + logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url) + return rewritten_url + async def requestEmailToken( self, - id_server: str, + id_server_url: str, email: str, client_secret: str, send_attempt: int, @@ -414,7 +458,7 @@ validation. Args: - id_server: The identity server to proxy to + id_server_url: The identity server to proxy to email: The email to send the message to client_secret: The unique client_secret sent by the user send_attempt: Which attempt this is @@ -428,6 +472,11 @@ "client_secret": client_secret, "send_attempt": send_attempt, } + + # if we have a rewrite rule set for the identity server, + # apply it now.
+ id_server_url = self.rewrite_id_server_url(id_server_url) + if next_link: params["next_link"] = next_link @@ -442,7 +491,8 @@ try: data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/email/requestToken", + "%s/_matrix/identity/api/v1/validate/email/requestToken" + % (id_server_url,), params, ) return data @@ -454,7 +504,7 @@ async def requestMsisdnToken( self, - id_server: str, + id_server_url: str, country: str, phone_number: str, client_secret: str, @@ -465,7 +515,7 @@ Request an external server send an SMS message on our behalf for the purposes of threepid validation. Args: - id_server: The identity server to proxy to + id_server_url: The identity server to proxy to country: The country code of the phone number phone_number: The number to send the message to client_secret: The unique client_secret sent by the user @@ -493,9 +543,13 @@ "details and update your config file." ) + # if we have a rewrite rule set for the identity server, + # apply it now. + id_server_url = self.rewrite_id_server_url(id_server_url) try: data = await self.http_client.post_json_get_json( - id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", + "%s/_matrix/identity/api/v1/validate/msisdn/requestToken" + % (id_server_url,), params, ) except HttpResponseException as e: @@ -591,6 +645,86 @@ logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) raise SynapseError(400, "Error contacting the identity server") + # TODO: The following two methods are used for proxying IS requests using + # the CS API. They should be consolidated with those in RoomMemberHandler + # https://github.com/matrix-org/synapse-dinsic/issues/25 + + async def proxy_lookup_3pid( + self, id_server: str, medium: str, address: str + ) -> JsonDict: + """Looks up a 3pid in the passed identity server. + + Args: + id_server: The server name (including port, if required) + of the identity server to use. + medium: The type of the third party identifier (e.g. "email"). + address: The third party identifier (e.g. "foo@example.com"). + + Returns: + The result of the lookup. See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" + ) + + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + try: + data = await self.http_client.get_json( + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), + {"medium": medium, "address": address}, + ) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %s: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + return data + + async def proxy_bulk_lookup_3pid( + self, id_server: str, threepids: List[List[str]] + ) -> JsonDict: + """Looks up given 3pids in the passed identity server. + + Args: + id_server: The server name (including port, if required) + of the identity server to use. + threepids: The third party identifiers to lookup, as + a list of 2-string sized lists ([medium, address]). + + Returns: + The result of the lookup.
See + https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup + for details + """ + if not self._enable_lookup: + raise AuthError( + 403, "Looking up third-party identifiers is denied from this server" + ) + + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + + try: + data = await self.http_client.post_json_get_json( + "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,), + {"threepids": threepids}, + ) + + except HttpResponseException as e: + logger.info("Proxied lookup failed: %r", e) + raise e.to_synapse_error() + except IOError as e: + logger.info("Failed to contact %s: %s", id_server, e) + raise ProxiedRequestError(503, "Failed to contact identity server") + + return data + async def lookup_3pid( self, id_server: str, @@ -611,10 +745,13 @@ class IdentityHandler(BaseHandler): Returns: the matrix ID of the 3pid, or None if it is not recognized. """ + # Rewrite id_server URL if necessary + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + if id_access_token is not None: try: results = await self._lookup_3pid_v2( - id_server, id_access_token, medium, address + id_server_url, id_access_token, medium, address ) return results @@ -632,16 +769,17 @@ class IdentityHandler(BaseHandler): logger.warning("Error when looking up hashing details: %s", e) return None - return await self._lookup_3pid_v1(id_server, medium, address) + return await self._lookup_3pid_v1(id_server, id_server_url, medium, address) async def _lookup_3pid_v1( - self, id_server: str, medium: str, address: str + self, id_server: str, id_server_url: str, medium: str, address: str ) -> Optional[str]: """Looks up a 3pid in the passed identity server using v1 lookup. Args: id_server: The server name (including port, if required) of the identity server to use. + id_server_url: The actual, reachable domain of the id server medium: The type of the third party identifier (e.g. "email"). address: The third party identifier (e.g. "foo@example.com"). @@ -649,8 +787,8 @@ class IdentityHandler(BaseHandler): the matrix ID of the 3pid, or None if it is not recognized. """ try: - data = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), + data = await self.http_client.get_json( + "%s/_matrix/identity/api/v1/lookup" % (id_server_url,), {"medium": medium, "address": address}, ) @@ -667,13 +805,12 @@ class IdentityHandler(BaseHandler): return None async def _lookup_3pid_v2( - self, id_server: str, id_access_token: str, medium: str, address: str + self, id_server_url: str, id_access_token: str, medium: str, address: str ) -> Optional[str]: """Looks up a 3pid in the passed identity server using v2 lookup. Args: - id_server: The server name (including port, if required) - of the identity server to use. + id_server_url: The protocol scheme and domain of the id server id_access_token: The access token to authenticate to the identity server with medium: The type of the third party identifier (e.g. "email"). address: The third party identifier (e.g. "foo@example.com"). 
@@ -683,8 +820,8 @@ class IdentityHandler(BaseHandler): """ # Check what hashing details are supported by this identity server try: - hash_details = await self.blacklisting_http_client.get_json( - "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), + hash_details = await self.http_client.get_json( + "%s/_matrix/identity/v2/hash_details" % (id_server_url,), {"access_token": id_access_token}, ) except RequestTimedOutError: @@ -692,15 +829,14 @@ class IdentityHandler(BaseHandler): if not isinstance(hash_details, dict): logger.warning( - "Got non-dict object when checking hash details of %s%s: %s", - id_server_scheme, - id_server, + "Got non-dict object when checking hash details of %s: %s", + id_server_url, hash_details, ) raise SynapseError( 400, - "Non-dict object from %s%s during v2 hash_details request: %s" - % (id_server_scheme, id_server, hash_details), + "Non-dict object from %s during v2 hash_details request: %s" + % (id_server_url, hash_details), ) # Extract information from hash_details @@ -714,8 +850,8 @@ class IdentityHandler(BaseHandler): ): raise SynapseError( 400, - "Invalid hash details received from identity server %s%s: %s" - % (id_server_scheme, id_server, hash_details), + "Invalid hash details received from identity server %s: %s" + % (id_server_url, hash_details), ) # Check if any of the supported lookup algorithms are present @@ -737,7 +873,7 @@ class IdentityHandler(BaseHandler): else: logger.warning( "None of the provided lookup algorithms of %s are supported: %s", - id_server, + id_server_url, supported_lookup_algorithms, ) raise SynapseError( @@ -750,8 +886,8 @@ class IdentityHandler(BaseHandler): headers = {"Authorization": create_id_access_token_header(id_access_token)} try: - lookup_results = await self.blacklisting_http_client.post_json_get_json( - "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), + lookup_results = await self.http_client.post_json_get_json( + "%s/_matrix/identity/v2/lookup" % (id_server_url,), { "addresses": [lookup_value], "algorithm": lookup_algorithm, @@ -839,15 +975,17 @@ class IdentityHandler(BaseHandler): if self._web_client_location: invite_config["org.matrix.web_client_location"] = self._web_client_location + # Rewrite the identity server URL if necessary + id_server_url = self.rewrite_id_server_url(id_server, add_https=True) + # Add the identity service access token to the JSON body and use the v2 # Identity Service endpoints if id_access_token is present data = None - base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server) + base_url = "%s/_matrix/identity" % (id_server_url,) if id_access_token: - key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % ( + id_server_url, ) # Attempt a v2 lookup @@ -866,9 +1004,8 @@ class IdentityHandler(BaseHandler): raise e if data is None: - key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, - id_server, + key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_url, ) url = base_url + "/api/v1/store-invite" @@ -880,10 +1017,7 @@ class IdentityHandler(BaseHandler): raise SynapseError(500, "Timed out contacting identity server") except HttpResponseException as e: logger.warning( - "Error trying to call /store-invite on %s%s: %s", - id_server_scheme, - id_server, - e, + "Error trying to call /store-invite on %s: %s", id_server_url, e, ) if data is None: @@ -896,10 +1030,9 @@ class 
IdentityHandler(BaseHandler): ) except HttpResponseException as e: logger.warning( - "Error calling /store-invite on %s%s with fallback " + "Error calling /store-invite on %s with fallback " "encoding: %s", - id_server_scheme, - id_server, + id_server_url, e, ) raise e @@ -920,6 +1053,42 @@ class IdentityHandler(BaseHandler): display_name = data["display_name"] return token, public_keys, fallback_public_key, display_name + async def bind_email_using_internal_sydent_api( + self, id_server_url: str, email: str, user_id: str, + ): + """Bind an email to a fully qualified user ID using the internal API of an + instance of Sydent. + + Args: + id_server_url: The URL of the Sydent instance + email: The email address to bind + user_id: The user ID to bind the email to + + Raises: + HTTPResponseException: On a non-2xx HTTP response. + """ + # Extract the domain name from the IS URL as we store IS domains instead of URLs + id_server = urllib.parse.urlparse(id_server_url).hostname + if not id_server: + # We were unable to determine the hostname, bail out + return + + # id_server_url is assumed to have no trailing slashes + url = id_server_url + "/_matrix/identity/internal/bind" + body = { + "address": email, + "medium": "email", + "mxid": user_id, + } + + # Bind the threepid + await self.http_client.post_json_get_json(url, body) + + # Remember where we bound the threepid + await self.store.add_user_bound_threepid( + user_id=user_id, medium="email", address=email, id_server=id_server, + ) + def create_id_access_token_header(id_access_token: str) -> List[str]: """Create an Authorization header for passing to SimpleHttpClient as the header value diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
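For the bind_email_using_internal_sydent_api helper added above, a call might look roughly like this (the Sydent URL and addresses are made up; the helper itself extracts the hostname for storage):

    await identity_handler.bind_email_using_internal_sydent_api(
        "https://sydent.internal", "alice@example.com", "@alice:example.com",
    )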
index fbd8df9dcc..78c3e5a10b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler): joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN] receipt = await self.store.get_linearized_receipts_for_rooms( - joined_rooms, to_key=int(now_token.receipt_key), + joined_rooms, + to_key=int(now_token.receipt_key), ) tags_by_room = await self.store.get_tags_for_user(user_id) @@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler): self.state_handler.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: - room_end_token = RoomStreamToken(None, event.stream_ordering,) + room_end_token = RoomStreamToken( + None, + event.stream_ordering, + ) deferred_room_state = run_in_background( self.state_store.get_state_for_events, [event.event_id] ) @@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler): membership, member_event_id, ) = await self.auth.check_user_in_room_or_world_readable( - room_id, user_id, allow_departed_users=True, + room_id, + user_id, + allow_departed_users=True, ) is_peeking = member_event_id is None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a15336bf00..1aded280c7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -40,6 +41,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.urls import ConsentURIBuilder +from synapse.config.api import DEFAULT_ROOM_STATE_TYPES from synapse.events import EventBase from synapse.events.builder import EventBuilder from synapse.events.snapshot import EventContext @@ -65,8 +67,7 @@ logger = logging.getLogger(__name__) class MessageHandler: - """Contains some read only APIs to get state about a room - """ + """Contains some read only APIs to get state about a room""" def __init__(self, hs): self.auth = hs.get_auth() @@ -88,9 +89,13 @@ ) async def get_room_data( - self, user_id: str, room_id: str, event_type: str, state_key: str, + self, + user_id: str, + room_id: str, + event_type: str, + state_key: str, ) -> dict: - """ Get data from a room. + """Get data from a room. Args: user_id @@ -174,7 +179,10 @@ raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False, + self.storage, + user_id, + last_events, + filter_send_to_client=False, ) event = last_events[0] @@ -494,7 +502,7 @@ membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership in {Membership.JOIN, Membership.INVITE}: + if membership in {Membership.JOIN, Membership.INVITE, Membership.KNOCK}: # If event doesn't include a display name, add one. profile = self.profile_handler content = builder.content @@ -571,7 +579,7 @@ async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester ) -> bool: - """"Determine if an event to be sent is exempt from having to consent + """Determine if an event to be sent is exempt from having to consent to the privacy policy @@ -793,9 +801,10 @@ """ if prev_event_ids is not None: - assert len(prev_event_ids) <= 10, ( - "Attempting to create an event with %i prev_events" - % (len(prev_event_ids),) + assert ( + len(prev_event_ids) <= 10 + ), "Attempting to create an event with %i prev_events" % ( + len(prev_event_ids), ) else: prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) @@ -821,7 +830,8 @@ ) if not third_party_result: logger.info( - "Event %s forbidden by third-party rules", event, + "Event %s forbidden by third-party rules", + event, ) raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN ) @@ -920,8 +930,8 @@ room_version = await self.store.get_room_version_id(event.room_id) if event.internal_metadata.is_out_of_band_membership(): - # the only sort of out-of-band-membership events we expect to see here - # are invite rejections we have generated ourselves. + # the only sort of out-of-band-membership events we expect to see here are + # invite rejections and rescinded knocks that we have generated ourselves.
assert event.type == EventTypes.Member assert event.content["membership"] == Membership.LEAVE else: @@ -1167,6 +1177,13 @@ class EventCreationHandler: # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) + if event.content["membership"] == Membership.KNOCK: + event.unsigned[ + "knock_room_state" + ] = await self.store.get_stripped_room_state_from_event_context( + context, DEFAULT_ROOM_STATE_TYPES, + ) + if event.type == EventTypes.Redaction: original_event = await self.store.get_event( event.redacts, @@ -1316,7 +1333,11 @@ class EventCreationHandler: # Since this is a dummy-event it is OK if it is sent by a # shadow-banned user. await self.handle_new_client_event( - requester, event, context, ratelimit=False, ignore_shadow_ban=True, + requester, + event, + context, + ratelimit=False, + ignore_shadow_ban=True, ) return True except AuthError: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 71008ec50d..07db1e31e4 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py
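The oidc_handler.py changes below deal with browsers' inconsistent SameSite support by using two session cookies, as the new comment block explains. On the setting side (elsewhere in the login flow), building the headers by hand would look roughly like this sketch (the token value is hypothetical):

    token = b"opaque-session-macaroon"
    for name, options in _SESSION_COOKIES:
        # request.addCookie cannot emit SameSite=None, so append a raw
        # Set-Cookie header instead
        request.cookies.append(b"%s=%s; Max-Age=3600; %s" % (name, token, options))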
@@ -41,13 +41,33 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -SESSION_COOKIE_NAME = b"oidc_session" +# we want the cookie to be returned to us even when the request is the POSTed +# result of a form on another domain, as is used with `response_mode=form_post`. +# +# Modern browsers will not do so unless we set SameSite=None; however *older* +# browsers (including all versions of Safari on iOS 12?) don't support +# SameSite=None, and interpret it as SameSite=Strict: +# https://bugs.webkit.org/show_bug.cgi?id=198181 +# +# As a rather painful workaround, we set *two* cookies, one with SameSite=None +# and one with no SameSite, in the hope that at least one of them will get +# back to us. +# +# Secure is necessary for SameSite=None (and, empirically, also breaks things +# on iOS 12.) +# +# Here we have the names of the cookies, and the options we use to set them. +_SESSION_COOKIES = [ + (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"), + (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"), +] #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and #: OpenID.Core sec 3.1.3.3. @@ -72,8 +92,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]}) class OidcHandler: - """Handles requests related to the OpenID Connect login flow. - """ + """Handles requests related to the OpenID Connect login flow.""" def __init__(self, hs: "HomeServer"): self._sso_handler = hs.get_sso_handler() @@ -123,7 +142,6 @@ class OidcHandler: Args: request: the incoming request from the browser. """ - # The provider might redirect with an error. # In that case, just display it as-is. if b"error" in request.args: @@ -137,8 +155,12 @@ class OidcHandler: # either the provider misbehaving or Synapse being misconfigured. # The only exception of that is "access_denied", where the user # probably cancelled the login flow. In other cases, log those errors. - if error != "access_denied": - logger.error("Error from the OIDC provider: %s %s", error, description) + logger.log( + logging.INFO if error == "access_denied" else logging.ERROR, + "Received OIDC callback with error: %s %s", + error, + description, + ) self._sso_handler.render_error(request, error, description) return @@ -146,30 +168,37 @@ class OidcHandler: # otherwise, it is presumably a successful response. see: # https://tools.ietf.org/html/rfc6749#section-4.1.2 - # Fetch the session cookie - session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes] - if session is None: - logger.info("No session cookie found") + # Fetch the session cookie. See the comments on SESSION_COOKIES for why there + # are two. + + for cookie_name, _ in _SESSION_COOKIES: + session = request.getCookie(cookie_name) # type: Optional[bytes] + if session is not None: + break + else: + logger.info("Received OIDC callback, with no session cookie") self._sso_handler.render_error( request, "missing_session", "No session cookie found" ) return - # Remove the cookie. There is a good chance that if the callback failed + # Remove the cookies. There is a good chance that if the callback failed # once, it will fail next time and the code will already be exchanged. 
- # Removing it early avoids spamming the provider with token requests. - request.addCookie( - SESSION_COOKIE_NAME, - b"", - path="/_synapse/oidc", - expires="Thu, Jan 01 1970 00:00:00 UTC", - httpOnly=True, - sameSite="lax", - ) + # Removing the cookies early avoids spamming the provider with token requests. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=; Expires=Thu, Jan 01 1970 00:00:00 UTC; %s" + % (cookie_name, options) + ) # Check for the state query parameter if b"state" not in request.args: - logger.info("State parameter is missing") + logger.info("Received OIDC callback, with no state parameter") self._sso_handler.render_error( request, "invalid_request", "State parameter is missing" ) @@ -183,14 +212,16 @@ class OidcHandler: session, state ) except (MacaroonDeserializationException, ValueError) as e: - logger.exception("Invalid session") + logger.exception("Invalid session for OIDC callback") self._sso_handler.render_error(request, "invalid_session", str(e)) return except MacaroonInvalidSignatureException as e: - logger.exception("Could not verify session") + logger.exception("Could not verify session for OIDC callback") self._sso_handler.render_error(request, "mismatching_session", str(e)) return + logger.info("Received OIDC callback for IdP %s", session_data.idp_id) + oidc_provider = self._providers.get(session_data.idp_id) if not oidc_provider: logger.error("OIDC session uses unknown IdP %r", oidc_provider) @@ -210,8 +241,7 @@ class OidcHandler: class OidcError(Exception): - """Used to catch errors when calling the token_endpoint - """ + """Used to catch errors when calling the token_endpoint""" def __init__(self, error, error_description=None): self.error = error @@ -240,22 +270,27 @@ class OidcProvider: self._token_generator = token_generator + self._config = provider self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( - provider.client_id, provider.client_secret, provider.client_auth_method, + provider.client_id, + provider.client_secret, + provider.client_auth_method, ) # type: ClientAuth self._client_auth_method = provider.client_auth_method - self._provider_metadata = OpenIDProviderMetadata( - issuer=provider.issuer, - authorization_endpoint=provider.authorization_endpoint, - token_endpoint=provider.token_endpoint, - userinfo_endpoint=provider.userinfo_endpoint, - jwks_uri=provider.jwks_uri, - ) # type: OpenIDProviderMetadata - self._provider_needs_discovery = provider.discover + + # cache of metadata for the identity provider (endpoint uris, mostly). This is + # loaded on-demand from the discovery endpoint (if discovery is enabled), with + # possible overrides from the config. Access via `load_metadata`. + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + # cache of JWKs used by the identity provider to sign tokens. Loaded on demand + # from the IdP's jwks_uri, if required. 
+ self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + self._user_mapping_provider = provider.user_mapping_provider_class( provider.user_mapping_provider_config ) @@ -281,7 +316,7 @@ class OidcProvider: self._sso_handler.register_identity_provider(self) - def _validate_metadata(self): + def _validate_metadata(self, m: OpenIDProviderMetadata) -> None: """Verifies the provider metadata. This checks the validity of the currently loaded provider. Not @@ -300,7 +335,6 @@ class OidcProvider: if self._skip_verification is True: return - m = self._provider_metadata m.validate_issuer() m.validate_authorization_endpoint() m.validate_token_endpoint() @@ -335,11 +369,7 @@ class OidcProvider: ) else: # If we're not using userinfo, we need a valid jwks to validate the ID token - if m.get("jwks") is None: - if m.get("jwks_uri") is not None: - m.validate_jwks_uri() - else: - raise ValueError('"jwks_uri" must be set') + m.validate_jwks_uri() @property def _uses_userinfo(self) -> bool: @@ -356,11 +386,15 @@ class OidcProvider: or self._user_profile_method == "userinfo_endpoint" ) - async def load_metadata(self) -> OpenIDProviderMetadata: - """Load and validate the provider metadata. + async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata: + """Return the provider metadata. - The values metadatas are discovered if ``oidc_config.discovery`` is - ``True`` and then cached. + If this is the first call, the metadata is built from the config and from the + metadata discovery endpoint (if enabled), and then validated. If the metadata + is successfully validated, it is then cached for future use. + + Args: + force: If true, any cached metadata is discarded to force a reload. Raises: ValueError: if something in the provider is not valid @@ -368,18 +402,41 @@ class OidcProvider: Returns: The provider's metadata. """ - # If we are using the OpenID Discovery documents, it needs to be loaded once - # FIXME: should there be a lock here? - if self._provider_needs_discovery: - url = get_well_known_url(self._provider_metadata["issuer"], external=True) + if force: + # reset the cached call to ensure we get a new result + self._provider_metadata = RetryOnExceptionCachedCall(self._load_metadata) + + return await self._provider_metadata.get() + + async def _load_metadata(self) -> OpenIDProviderMetadata: + # start out with just the issuer (unlike the other settings, discovered issuer + # takes precedence over configured issuer, because configured issuer is + # required for discovery to take place.) + # + metadata = OpenIDProviderMetadata(issuer=self._config.issuer) + + # load any data from the discovery endpoint, if enabled + if self._config.discover: + url = get_well_known_url(self._config.issuer, external=True) metadata_response = await self._http_client.get_json(url) - # TODO: maybe update the other way around to let user override some values? 
- self._provider_metadata.update(metadata_response) - self._provider_needs_discovery = False + metadata.update(metadata_response) + + # override any discovered data with any settings in our config + if self._config.authorization_endpoint: + metadata["authorization_endpoint"] = self._config.authorization_endpoint + + if self._config.token_endpoint: + metadata["token_endpoint"] = self._config.token_endpoint - self._validate_metadata() + if self._config.userinfo_endpoint: + metadata["userinfo_endpoint"] = self._config.userinfo_endpoint - return self._provider_metadata + if self._config.jwks_uri: + metadata["jwks_uri"] = self._config.jwks_uri + + self._validate_metadata(metadata) + + return metadata async def load_jwks(self, force: bool = False) -> JWKS: """Load the JSON Web Key Set used to sign ID tokens. @@ -409,27 +466,27 @@ class OidcProvider: ] } """ + if force: + # reset the cached call to ensure we get a new result + self._jwks = RetryOnExceptionCachedCall(self._load_jwks) + return await self._jwks.get() + + async def _load_jwks(self) -> JWKS: if self._uses_userinfo: # We're not using jwt signing, return an empty jwk set return {"keys": []} - # First check if the JWKS are loaded in the provider metadata. - # It can happen either if the provider gives its JWKS in the discovery - # document directly or if it was already loaded once. metadata = await self.load_metadata() - jwk_set = metadata.get("jwks") - if jwk_set is not None and not force: - return jwk_set - # Loading the JWKS using the `jwks_uri` metadata + # Load the JWKS using the `jwks_uri` metadata. uri = metadata.get("jwks_uri") if not uri: + # this should be unreachable: load_metadata validates that + # there is a jwks_uri in the metadata if _uses_userinfo is unset raise RuntimeError('Missing "jwks_uri" in metadata') jwk_set = await self._http_client.get_json(uri) - # Caching the JWKS in the provider's metadata - self._provider_metadata["jwks"] = jwk_set return jwk_set async def _exchange_code(self, code: str) -> Token: @@ -487,7 +544,10 @@ class OidcProvider: # We're not using the SimpleHttpClient util methods as we don't want to # check the HTTP status code and we do the body encoding ourself. response = await self._http_client.request( - method="POST", uri=uri, data=body.encode("utf-8"), headers=headers, + method="POST", + uri=uri, + data=body.encode("utf-8"), + headers=headers, ) # This is used in multiple error messages below @@ -565,6 +625,7 @@ class OidcProvider: Returns: UserInfo: an object representing the user. 
""" + logger.debug("Using the OAuth2 access_token to request userinfo") metadata = await self.load_metadata() resp = await self._http_client.get_json( @@ -572,6 +633,8 @@ class OidcProvider: headers={"Authorization": ["Bearer {}".format(token["access_token"])]}, ) + logger.debug("Retrieved user info from userinfo endpoint: %r", resp) + return UserInfo(resp) async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo: @@ -600,17 +663,19 @@ class OidcProvider: claims_cls = ImplicitIDToken alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"]) - jwt = JsonWebToken(alg_values) claim_options = {"iss": {"values": [metadata["issuer"]]}} + id_token = token["id_token"] + logger.debug("Attempting to decode JWT id_token %r", id_token) + # Try to decode the keys in cache first, then retry by forcing the keys # to be reloaded jwk_set = await self.load_jwks() try: claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, @@ -620,13 +685,15 @@ class OidcProvider: logger.info("Reloading JWKS after decode error") jwk_set = await self.load_jwks(force=True) # try reloading the jwks claims = jwt.decode( - token["id_token"], + id_token, key=jwk_set, claims_cls=claims_cls, claims_options=claim_options, claims_params=claims_params, ) + logger.debug("Decoded id_token JWT %r; validating", claims) + claims.validate(leeway=120) # allows 2 min of clock skew return UserInfo(claims) @@ -681,14 +748,18 @@ class OidcProvider: ui_auth_session_id=ui_auth_session_id, ), ) - request.addCookie( - SESSION_COOKIE_NAME, - cookie, - path="/_synapse/client/oidc", - max_age="3600", - httpOnly=True, - sameSite="lax", - ) + + # Set the cookies. See the comments on _SESSION_COOKIES for why there are two. + # + # we have to build the header by hand rather than calling request.addCookie + # because the latter does not support SameSite=None + # (https://twistedmatrix.com/trac/ticket/10088) + + for cookie_name, options in _SESSION_COOKIES: + request.cookies.append( + b"%s=%s; Max-Age=3600; %s" + % (cookie_name, cookie.encode("utf-8"), options) + ) metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") @@ -726,19 +797,18 @@ class OidcProvider: """ # Exchange the code with the provider try: - logger.debug("Exchanging code") + logger.debug("Exchanging OAuth2 code for a token") token = await self._exchange_code(code) except OidcError as e: - logger.exception("Could not exchange code") + logger.exception("Could not exchange OAuth2 code") self._sso_handler.render_error(request, e.error, e.error_description) return - logger.debug("Successfully obtained OAuth2 access token") + logger.debug("Successfully obtained OAuth2 token data: %r", token) # Now that we have a token, get the userinfo, either by decoding the # `id_token` or by fetching the `userinfo_endpoint`. if self._uses_userinfo: - logger.debug("Fetching userinfo") try: userinfo = await self._fetch_userinfo(token) except Exception as e: @@ -746,7 +816,6 @@ class OidcProvider: self._sso_handler.render_error(request, "fetch_error", str(e)) return else: - logger.debug("Extracting userinfo from id_token") try: userinfo = await self._parse_id_token(token, nonce=session_data.nonce) except Exception as e: @@ -939,7 +1008,9 @@ class OidcSessionTokenGenerator: A signed macaroon token with the session information. 
""" macaroon = pymacaroons.Macaroon( - location=self._server_name, identifier="key", key=self._macaroon_secret_key, + location=self._server_name, + identifier="key", + key=self._macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = session") diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5372753707..059064a4eb 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -197,7 +197,8 @@ class PaginationHandler: stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) r = await self.store.get_room_event_before_stream_ordering( - room_id, stream_ordering, + room_id, + stream_ordering, ) if not r: logger.warning( @@ -223,7 +224,12 @@ class PaginationHandler: # the background so that it's not blocking any other operation apart from # other purges in the same room. run_as_background_process( - "_purge_history", self._purge_history, purge_id, room_id, token, True, + "_purge_history", + self._purge_history, + purge_id, + room_id, + token, + True, ) def start_purge_history( @@ -389,7 +395,9 @@ class PaginationHandler: ) await self.hs.get_federation_handler().maybe_backfill( - room_id, curr_topo, limit=pagin_config.limit, + room_id, + curr_topo, + limit=pagin_config.limit, ) to_room_key = None diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 22d1e9d35c..fb85b19770 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -349,10 +349,13 @@ class PresenceHandler(BasePresenceHandler): [self.user_to_current_state[user_id] for user_id in unpersisted] ) - async def _update_states(self, new_states): + async def _update_states(self, new_states: Iterable[UserPresenceState]) -> None: """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state should be sent to clients/servers. + + Args: + new_states: The new user presence state updates to process. """ now = self.clock.time_msec() @@ -368,7 +371,7 @@ class PresenceHandler(BasePresenceHandler): new_states_dict = {} for new_state in new_states: new_states_dict[new_state.user_id] = new_state - new_state = new_states_dict.values() + new_states = new_states_dict.values() for new_state in new_states: user_id = new_state.user_id @@ -635,8 +638,7 @@ class PresenceHandler(BasePresenceHandler): self.external_process_last_updated_ms.pop(process_id, None) async def current_state_for_user(self, user_id): - """Get the current presence state for a user. - """ + """Get the current presence state for a user.""" res = await self.current_state_for_users([user_id]) return res[user_id] @@ -658,17 +660,6 @@ class PresenceHandler(BasePresenceHandler): self._push_to_remotes(states) - async def notify_for_states(self, state, stream_id): - parties = await get_interested_parties(self.store, [state]) - room_ids_to_states, users_to_states = parties - - self.notifier.on_new_event( - "presence_key", - stream_id, - rooms=room_ids_to_states.keys(), - users=[UserID.from_string(u) for u in users_to_states], - ) - def _push_to_remotes(self, states): """Sends state updates to remote servers. @@ -678,8 +669,7 @@ class PresenceHandler(BasePresenceHandler): self.federation.send_presence(states) async def incoming_presence(self, origin, content): - """Called when we receive a `m.presence` EDU from a remote server. - """ + """Called when we receive a `m.presence` EDU from a remote server.""" if not self._presence_enabled: return @@ -729,8 +719,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states(updates) async def set_state(self, target_user, state, ignore_status_msg=False): - """Set the presence state of the user. - """ + """Set the presence state of the user.""" status_msg = state.get("status_msg", None) presence = state["presence"] @@ -758,8 +747,7 @@ class PresenceHandler(BasePresenceHandler): await self._update_states([prev_state.copy_and_replace(**new_fields)]) async def is_visible(self, observed_user, observer_user): - """Returns whether a user can see another user's presence. - """ + """Returns whether a user can see another user's presence.""" observer_room_ids = await self.store.get_rooms_for_user( observer_user.to_string() ) @@ -953,8 +941,7 @@ class PresenceHandler(BasePresenceHandler): def should_notify(old_state, new_state): - """Decides if a presence state change should be sent to interested parties. - """ + """Decides if a presence state change should be sent to interested parties.""" if old_state == new_state: return False diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
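The one functional change in the synapse/handlers/presence.py hunks above is the new_state/new_states rebinding fix in _update_states: previously the deduplicated dict values were assigned to the loop variable and then ignored. A small self-contained illustration of the keep-latest-per-user pattern being restored (data invented):

updates = [("@a:hs", "online"), ("@a:hs", "unavailable"), ("@b:hs", "online")]
latest = {}
for user_id, presence in updates:
    # later entries overwrite earlier ones, keeping the most recent state
    latest[user_id] = presence
updates = list(latest.items())  # [("@a:hs", "unavailable"), ("@b:hs", "online")]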
index c02b951031..b04ee5f430 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +15,11 @@ # limitations under the License. import logging import random -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, List, Optional + +from signedjson.sign import sign_json + +from twisted.internet import reactor from synapse.api.errors import ( AuthError, @@ -24,7 +29,11 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.logging.context import run_in_background +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.types import ( JsonDict, Requester, @@ -54,6 +63,8 @@ class ProfileHandler(BaseHandler): PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 + PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000 + def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -64,11 +75,98 @@ class ProfileHandler(BaseHandler): self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() + + self.max_avatar_size = hs.config.max_avatar_size + self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes + self.replicate_user_profiles_to = hs.config.replicate_user_profiles_to + if hs.config.run_background_tasks: self.clock.looping_call( self._update_remote_profile_cache, self.PROFILE_UPDATE_MS ) + if len(self.hs.config.replicate_user_profiles_to) > 0: + reactor.callWhenRunning(self._do_assign_profile_replication_batches) + reactor.callWhenRunning(self._start_replicate_profiles) + # Add a looping call to replicate_profiles: this handles retries + # if the replication is unsuccessful when the user updated their + # profile. + self.clock.looping_call( + self._start_replicate_profiles, self.PROFILE_REPLICATE_INTERVAL + ) + + def _do_assign_profile_replication_batches(self): + return run_as_background_process( + "_assign_profile_replication_batches", + self._assign_profile_replication_batches, + ) + + def _start_replicate_profiles(self): + return run_as_background_process( + "_replicate_profiles", self._replicate_profiles + ) + + async def _assign_profile_replication_batches(self): + """If no profile replication has been done yet, allocate replication batch + numbers to each profile to start the replication process. + """ + logger.info("Assigning profile batch numbers...") + total = 0 + while True: + assigned = await self.store.assign_profile_batch() + total += assigned + if assigned == 0: + break + logger.info("Assigned %d profile batch numbers", total) + + async def _replicate_profiles(self): + """If any profile data has been updated and not pushed to the replication targets, + replicate it. 
+ """ + host_batches = await self.store.get_replication_hosts() + latest_batch = await self.store.get_latest_profile_replication_batch_number() + if latest_batch is None: + latest_batch = -1 + for repl_host in self.hs.config.replicate_user_profiles_to: + if repl_host not in host_batches: + host_batches[repl_host] = -1 + try: + for i in range(host_batches[repl_host] + 1, latest_batch + 1): + await self._replicate_host_profile_batch(repl_host, i) + except Exception: + logger.exception( + "Exception while replicating to %s: aborting for now", repl_host + ) + + async def _replicate_host_profile_batch(self, host, batchnum): + logger.info("Replicating profile batch %d to %s", batchnum, host) + batch_rows = await self.store.get_profile_batch(batchnum) + batch = { + UserID(r["user_id"], self.hs.hostname).to_string(): ( + {"display_name": r["displayname"], "avatar_url": r["avatar_url"]} + if r["active"] + else None + ) + for r in batch_rows + } + + url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,) + body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname} + signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0]) + try: + await self.http_client.post_json_get_json(url, signed_body) + await self.store.update_replication_batch_for_host(host, batchnum) + logger.info( + "Successfully replicated profile batch %d to %s", batchnum, host + ) + except Exception: + # This will get retried when the looping call next comes around + logger.exception( + "Failed to replicate profile batch %d to %s", batchnum, host + ) + raise + async def get_profile(self, user_id: str) -> JsonDict: target_user = UserID.from_string(user_id) @@ -207,11 +305,20 @@ class ProfileHandler(BaseHandler): # This must be done by the target user himself. if by_admin: requester = create_requester( - target_user, authenticated_entity=requester.authenticated_entity, + target_user, + authenticated_entity=requester.authenticated_entity, + ) + + if len(self.hs.config.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None await self.store.set_profile_displayname( - target_user.localpart, displayname_to_set + target_user.localpart, displayname_to_set, new_batchnum ) if self.hs.config.user_directory_search_all_users: @@ -222,6 +329,46 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) + # start a profile replication push + run_in_background(self._replicate_profiles) + + async def set_active( + self, users: List[UserID], active: bool, hide: bool, + ): + """ + Sets the 'active' flag on a set of user profiles. If set to false, the + accounts are considered deactivated or hidden. + + If 'hide' is true, then we interpret active=False as a request to try to + hide the users rather than deactivating them. This means withholding the + profiles from replication (and mark it as inactive) rather than clearing + the profile from the HS DB. + + Note that unlike set_displayname and set_avatar_url, this does *not* + perform authorization checks! This is because the only place it's used + currently is in account deactivation where we've already done these + checks anyway. + + Args: + users: The users to modify + active: Whether to set the user to active or inactive + hide: Whether to hide the user (withold from replication). 
If + False and active is False, user will have their profile + erased + """ + if len(self.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + + await self.store.set_profiles_active(users, active, hide, new_batchnum) + + # start a profile replication push + run_in_background(self._replicate_profiles) + async def get_avatar_url(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: @@ -290,14 +437,56 @@ class ProfileHandler(BaseHandler): if new_avatar_url == "": avatar_url_to_set = None + # Enforce a max avatar size if one is defined + if avatar_url_to_set and ( + self.max_avatar_size or self.allowed_avatar_mimetypes + ): + media_id = self._validate_and_parse_media_id_from_avatar_url( + avatar_url_to_set + ) + + # Check that this media exists locally + media_info = await self.store.get_local_media(media_id) + if not media_info: + raise SynapseError( + 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND + ) + + # Ensure avatar does not exceed max allowed avatar size + media_size = media_info["media_length"] + if self.max_avatar_size and media_size > self.max_avatar_size: + raise SynapseError( + 400, + "Avatars must be less than %s bytes in size" + % (self.max_avatar_size,), + errcode=Codes.TOO_LARGE, + ) + + # Ensure the avatar's file type is allowed + if ( + self.allowed_avatar_mimetypes + and media_info["media_type"] not in self.allowed_avatar_mimetypes + ): + raise SynapseError( + 400, "Avatar file type '%s' not allowed" % media_info["media_type"] + ) + # Same like set_displayname if by_admin: requester = create_requester( target_user, authenticated_entity=requester.authenticated_entity ) + if len(self.hs.config.replicate_user_profiles_to) > 0: + cur_batchnum = ( + await self.store.get_latest_profile_replication_batch_number() + ) + new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1 + else: + new_batchnum = None + await self.store.set_profile_avatar_url( - target_user.localpart, avatar_url_to_set + target_user.localpart, avatar_url_to_set, new_batchnum ) if self.hs.config.user_directory_search_all_users: @@ -308,6 +497,23 @@ class ProfileHandler(BaseHandler): await self._update_join_states(requester, target_user) + # start a profile replication push + run_in_background(self._replicate_profiles) + + def _validate_and_parse_media_id_from_avatar_url(self, mxc): + """Validate and parse a provided avatar url and return the local media id + + Args: + mxc (str): A mxc URL + + Returns: + str: The ID of the media + """ + avatar_pieces = mxc.split("/") + if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:": + raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc) + return avatar_pieces[-1] + async def on_profile_query(self, args: JsonDict) -> JsonDict: user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
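For reference, the request body that _replicate_host_profile_batch in the synapse/handlers/profile.py hunks above sends to each replication target looks roughly like this (hostnames and values invented; sign_json additionally attaches a signatures key):

{
    "batchnum": 42,
    "origin_server": "hs.example.com",
    "batch": {
        # active profile: current displayname and avatar are replicated
        "@alice:hs.example.com": {
            "display_name": "Alice",
            "avatar_url": "mxc://hs.example.com/abcd1234",
        },
        # deactivated or hidden profile: replicated as None
        "@bob:hs.example.com": None,
    },
}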
index cc21fc2284..6a6c528849 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py
@@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler): ) else: hs.get_federation_registry().register_instances_for_edu( - "m.receipt", hs.config.worker.writers.receipts, + "m.receipt", + hs.config.worker.writers.receipts, ) self.clock = self.hs.get_clock() self.state = hs.get_state_handler() async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: - """Called when we receive an EDU of type m.receipt from a remote HS. - """ + """Called when we receive an EDU of type m.receipt from a remote HS.""" receipts = [] for room_id, room_values in content.items(): for receipt_type, users in room_values.items(): @@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler): await self._handle_new_receipts(receipts) async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: - """Takes a list of receipts, stores them and informs the notifier. - """ + """Takes a list of receipts, stores them and informs the notifier.""" min_batch_id = None # type: Optional[int] max_batch_id = None # type: Optional[int] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 49b085269b..553fcb5b66 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -49,6 +49,7 @@ class RegistrationHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() self.profile_handler = hs.get_profile_handler() self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() self.identity_handler = self.hs.get_identity_handler() self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() @@ -57,13 +58,15 @@ class RegistrationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() + self._show_in_user_directory = self.hs.config.show_users_in_user_directory + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( hs ) - self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client( - hs + self._post_registration_client = ( + ReplicationPostRegisterActionsServlet.make_client(hs) ) else: self.device_handler = hs.get_device_handler() @@ -77,6 +80,16 @@ class RegistrationHandler(BaseHandler): guest_access_token: Optional[str] = None, assigned_user_id: Optional[str] = None, ): + """ + + Args: + localpart (str|None): The user's localpart + guest_access_token (str|None): A guest's access token + assigned_user_id (str|None): An existing User ID for this user if pre-calculated + + Returns: + Deferred + """ if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, @@ -119,6 +132,8 @@ class RegistrationHandler(BaseHandler): raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) + + # Retrieve guest user information from provided access token user_data = await self.auth.get_user_by_access_token(guest_access_token) if ( not user_data.is_guest @@ -189,12 +204,15 @@ class RegistrationHandler(BaseHandler): self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( - threepid, localpart, user_agent_ips or [], + threepid, + localpart, + user_agent_ips or [], ) if result == RegistrationBehaviour.DENY: logger.info( - "Blocked registration of %r", localpart, + "Blocked registration of %r", + localpart, ) # We return a 429 to make it not obvious that they've been # denied. @@ -203,7 +221,8 @@ class RegistrationHandler(BaseHandler): shadow_banned = result == RegistrationBehaviour.SHADOW_BAN if shadow_banned: logger.info( - "Shadow banning registration of %r", localpart, + "Shadow banning registration of %r", + localpart, ) # do not check_auth_blocking if the call is coming through the Admin API @@ -237,6 +256,12 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) + if default_display_name: + requester = create_requester(user) + await self.profile_handler.set_displayname( + user, requester, default_display_name, by_admin=True + ) + if self.hs.config.user_directory_search_all_users: profile = await self.store.get_profileinfo(localpart) await self.user_directory_handler.handle_local_profile_change( @@ -246,8 +271,6 @@ class RegistrationHandler(BaseHandler): else: # autogen a sequential user ID fail_count = 0 - # If a default display name is not given, generate one. - generate_display_name = default_display_name is None # This breaks on successful registration *or* errors after 10 failures. 
while True: # Fail after being unable to find a suitable ID a few times @@ -258,7 +281,7 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() self.check_user_id_not_appservice_exclusive(user_id) - if generate_display_name: + if default_display_name is None: default_display_name = localpart try: await self.register_with_store( @@ -270,6 +293,11 @@ class RegistrationHandler(BaseHandler): shadow_banned=shadow_banned, ) + requester = create_requester(user) + await self.profile_handler.set_displayname( + user, requester, default_display_name, by_admin=True + ) + # Successfully registered break except SynapseError: @@ -301,7 +329,15 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - await self._register_email_threepid(user_id, threepid_dict, None) + await self.register_email_threepid(user_id, threepid_dict, None) + + # Prevent the new user from showing up in the user directory if the server + # mandates it. + if not self._show_in_user_directory: + await self.store.add_account_data_for_user( + user_id, "im.vector.hide_profile", {"hide_profile": True} + ) + await self.profile_handler.set_active([user], False, True) return user_id @@ -369,7 +405,9 @@ class RegistrationHandler(BaseHandler): config["room_alias_name"] = room_alias.localpart info, _ = await room_creation_handler.create_room( - fake_requester, config=config, ratelimit=False, + fake_requester, + config=config, + ratelimit=False, ) # If the room does not require an invite, but another user @@ -501,7 +539,10 @@ class RegistrationHandler(BaseHandler): """ await self._auto_join_rooms(user_id) - async def appservice_register(self, user_localpart: str, as_token: str) -> str: + async def appservice_register( + self, user_localpart: str, as_token: str, password_hash: str, display_name: str + ): + # FIXME: this should be factored out and merged with normal register() user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -518,12 +559,26 @@ class RegistrationHandler(BaseHandler): self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service) + display_name = display_name or user.localpart + await self.register_with_store( user_id=user_id, - password_hash="", + password_hash=password_hash, appservice_id=service_id, - create_profile_with_displayname=user.localpart, + create_profile_with_displayname=display_name, ) + + requester = create_requester(user) + await self.profile_handler.set_displayname( + user, requester, display_name, by_admin=True + ) + + if self.hs.config.user_directory_search_all_users: + profile = await self.store.get_profileinfo(user_localpart) + await self.user_directory_handler.handle_local_profile_change( + user_id, profile + ) + return user_id def check_user_id_not_appservice_exclusive( @@ -552,6 +607,37 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) + async def shadow_register(self, localpart, display_name, auth_result, params): + """Invokes the current registration on another server, using + shared secret registration, passing in any auth_results from + other registration UI auth flows (e.g. validated 3pids) + Useful for setting up shadow/backup accounts on a parallel deployment. 
+ """ + + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token), + { + # XXX: auth_result is an unspecified extension for shadow registration + "auth_result": auth_result, + # XXX: another unspecified extension for shadow registration to ensure + # that the displayname is correctly set by the masters erver + "display_name": display_name, + "username": localpart, + "password": params.get("password"), + "bind_msisdn": params.get("bind_msisdn"), + "device_id": params.get("device_id"), + "initial_device_display_name": params.get( + "initial_device_display_name" + ), + "inhibit_login": False, + "access_token": as_token, + }, + ) + def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -704,6 +790,7 @@ class RegistrationHandler(BaseHandler): if auth_result and LoginType.EMAIL_IDENTITY in auth_result: threepid = auth_result[LoginType.EMAIL_IDENTITY] + # Necessary due to auth checks prior to the threepid being # written to the db if is_threepid_reserved( @@ -711,7 +798,32 @@ class RegistrationHandler(BaseHandler): ): await self.store.upsert_monthly_active_user(user_id) - await self._register_email_threepid(user_id, threepid, access_token) + await self.register_email_threepid(user_id, threepid, access_token) + + if self.hs.config.bind_new_user_emails_to_sydent: + # Attempt to call Sydent's internal bind API on the given identity server + # to bind this threepid + id_server_url = self.hs.config.bind_new_user_emails_to_sydent + + logger.debug( + "Attempting the bind email of %s to identity server: %s using " + "internal Sydent bind API.", + user_id, + self.hs.config.bind_new_user_emails_to_sydent, + ) + + try: + await self.identity_handler.bind_email_using_internal_sydent_api( + id_server_url, threepid["address"], user_id + ) + except Exception as e: + logger.warning( + "Failed to bind email of '%s' to Sydent instance '%s' ", + "using Sydent internal bind API: %s", + user_id, + id_server_url, + e, + ) if auth_result and LoginType.MSISDN in auth_result: threepid = auth_result[LoginType.MSISDN] @@ -731,7 +843,7 @@ class RegistrationHandler(BaseHandler): await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - async def _register_email_threepid( + async def register_email_threepid( self, user_id: str, threepid: dict, token: Optional[str] ) -> None: """Add an email address as a 3pid identifier @@ -753,7 +865,10 @@ class RegistrationHandler(BaseHandler): return await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) # And we add an email pusher for them by default, but only @@ -805,5 +920,8 @@ class RegistrationHandler(BaseHandler): raise await self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"], + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 07b2187eb1..2271c60afc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -38,6 +38,7 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents +from synapse.rest.admin._base import assert_user_is_admin from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -197,7 +198,9 @@ class RoomCreationHandler(BaseHandler): if r is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) new_room_id = await self._generate_room_id( - creator_id=user_id, is_public=r["is_public"], room_version=new_version, + creator_id=user_id, + is_public=r["is_public"], + room_version=new_version, ) logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) @@ -235,7 +238,9 @@ class RoomCreationHandler(BaseHandler): # now send the tombstone await self.event_creation_handler.handle_new_client_event( - requester=requester, event=tombstone_event, context=tombstone_context, + requester=requester, + event=tombstone_event, + context=tombstone_context, ) old_room_state = await tombstone_context.get_current_state_ids() @@ -256,7 +261,10 @@ class RoomCreationHandler(BaseHandler): # finally, shut down the PLs in the old room, and update them in the new # room. await self._update_upgraded_room_pls( - requester, old_room_id, new_room_id, old_room_state, + requester, + old_room_id, + new_room_id, + old_room_state, ) return new_room_id @@ -363,7 +371,19 @@ class RoomCreationHandler(BaseHandler): """ user_id = requester.user.to_string() - if not await self.spam_checker.user_may_create_room(user_id): + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to create rooms + is_requester_admin = True + + else: + is_requester_admin = await self.auth.is_server_admin(requester.user) + + if not is_requester_admin and not await self.spam_checker.user_may_create_room( + user_id, invite_list=[], third_party_invite_list=[], cloning=True + ): raise SynapseError(403, "You are not permitted to create rooms") creation_content = { @@ -424,17 +444,20 @@ class RoomCreationHandler(BaseHandler): # Copy over user power levels now as this will not be possible with >100PL users once # the room has been created - # Calculate the minimum power level needed to clone the room event_power_levels = power_levels.get("events", {}) - state_default = power_levels.get("state_default", 0) - ban = power_levels.get("ban") + state_default = power_levels.get("state_default", 50) + ban = power_levels.get("ban", 50) needed_power_level = max(state_default, ban, max(event_power_levels.values())) + # Get the user's current power level, this matches the logic in get_user_power_level, + # but without the entire state map. + user_power_levels = power_levels.setdefault("users", {}) + users_default = power_levels.get("users_default", 0) + current_power_level = user_power_levels.get(user_id, users_default) # Raise the requester's power level in the new room if necessary - current_power_level = power_levels["users"][user_id] if current_power_level < needed_power_level: - power_levels["users"][user_id] = needed_power_level + user_power_levels[user_id] = needed_power_level await self._send_events_for_new_room( requester, @@ -566,7 +589,7 @@ class RoomCreationHandler(BaseHandler): ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, ) -> Tuple[dict, int]: - """ Creates a new room. + """Creates a new room. 
Args: requester: @@ -614,8 +637,14 @@ class RoomCreationHandler(BaseHandler): 403, "You are not permitted to create rooms", Codes.FORBIDDEN ) + invite_list = config.get("invite", []) + invite_3pid_list = config.get("invite_3pid", []) + if not is_requester_admin and not await self.spam_checker.user_may_create_room( - user_id + user_id, + invite_list=invite_list, + third_party_invite_list=invite_3pid_list, + cloning=False, ): raise SynapseError(403, "You are not permitted to create rooms") @@ -687,7 +716,9 @@ class RoomCreationHandler(BaseHandler): is_public = visibility == "public" room_id = await self._generate_room_id( - creator_id=user_id, is_public=is_public, room_version=room_version, + creator_id=user_id, + is_public=is_public, + room_version=room_version, ) # Check whether this visibility value is blocked by a third party module @@ -803,6 +834,7 @@ class RoomCreationHandler(BaseHandler): "invite", ratelimit=False, content=content, + new_room=True, ) for invite_3pid in invite_3pid_list: @@ -820,6 +852,7 @@ class RoomCreationHandler(BaseHandler): id_server, requester, txn_id=None, + new_room=True, id_access_token=id_access_token, ) @@ -828,7 +861,7 @@ class RoomCreationHandler(BaseHandler): if room_alias: result["room_alias"] = room_alias.to_string() - # Always wait for room creation to progate before returning + # Always wait for room creation to propagate before returning await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), "events", @@ -880,7 +913,10 @@ class RoomCreationHandler(BaseHandler): _, last_stream_id, ) = await self.event_creation_handler.create_and_send_nonmember_event( - creator, event, ratelimit=False, ignore_shadow_ban=True, + creator, + event, + ratelimit=False, + ignore_shadow_ban=True, ) return last_stream_id @@ -897,6 +933,7 @@ class RoomCreationHandler(BaseHandler): "join", ratelimit=ratelimit, content=creator_join_profile, + new_room=True, ) # We treat the power levels override specially as this needs to be one @@ -980,7 +1017,10 @@ class RoomCreationHandler(BaseHandler): return last_sent_stream_id async def _generate_room_id( - self, creator_id: str, is_public: bool, room_version: RoomVersion, + self, + creator_id: str, + is_public: bool, + room_version: RoomVersion, ): # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -1004,41 +1044,51 @@ class RoomCreationHandler(BaseHandler): class RoomContextHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.auth = hs.get_auth() self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state async def get_event_context( self, - user: UserID, + requester: Requester, room_id: str, event_id: str, limit: int, event_filter: Optional[Filter], + use_admin_priviledge: bool = False, ) -> Optional[JsonDict]: """Retrieves events, pagination tokens and state around a given event in a room. Args: - user + requester room_id event_id limit: The maximum number of events to return in total (excluding state). event_filter: the filter to apply to the events returned (excluding the target event_id) - + use_admin_priviledge: if `True`, return all events, regardless + of whether `user` has access to them. To be used **ONLY** + from the admin API. 
Returns: dict, or None if the event isn't found """ + user = requester.user + if use_admin_priviledge: + await assert_user_is_admin(self.auth, requester.user) + before_limit = math.floor(limit / 2.0) after_limit = limit - before_limit users = await self.store.get_users_in_room(room_id) is_peeking = user.to_string() not in users - def filter_evts(events): - return filter_events_for_client( + async def filter_evts(events): + if use_admin_priviledge: + return events + return await filter_events_for_client( self.storage, user.to_string(), events, is_peeking=is_peeking ) diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
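A worked example of the upgrade-time power-level bump in the synapse/handlers/room.py hunks above (numbers invented). With the defaults of 50 now applied for state_default and ban, a creator whose level is below the highest required level gets raised so the room state can still be cloned:

power_levels = {
    "events": {"m.room.power_levels": 100, "m.room.tombstone": 100},
    "state_default": 50,
    "ban": 50,
    "users": {"@creator:example.com": 60},
    "users_default": 0,
}

needed_power_level = max(
    power_levels["state_default"],
    power_levels["ban"],
    max(power_levels["events"].values()),
)  # -> 100

user_power_levels = power_levels.setdefault("users", {})
current = user_power_levels.get("@creator:example.com", power_levels["users_default"])
if current < needed_power_level:
    user_power_levels["@creator:example.com"] = needed_power_level  # 60 -> 100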
index 14f14db449..373b9dcd0d 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
@@ -170,6 +170,7 @@ class RoomListHandler(BaseHandler): "world_readable": room["history_visibility"] == HistoryVisibility.WORLD_READABLE, "guest_can_join": room["guest_access"] == "can_join", + "join_rule": room["join_rules"], } # Filter out Nones – rather omit the field altogether diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
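The new join_rule field in the synapse/handlers/room_list.py hunk above surfaces the room's join rule in each public-rooms chunk; note the stored column is plural (join_rules) while the response field is singular. An illustrative chunk entry, with invented values and the other fields assumed from the usual public rooms response shape:

{
    "room_id": "!abcdefg:example.com",
    "num_joined_members": 12,
    "world_readable": False,
    "guest_can_join": True,
    "join_rule": "public",
}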
index a5da97cfe0..312ebc139c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2016-2020 The Matrix.org Foundation C.I.C. +# Copyright 2020 Sorunome # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import abc import logging import random @@ -31,7 +31,15 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomAlias, + RoomID, + StateMap, + UserID, + get_domain_from_id, +) from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room @@ -122,6 +130,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + async def remote_knock( + self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict, + ) -> Tuple[str, int]: + """Try and knock on a room that this server is not in + + Args: + remote_room_hosts: List of servers that can be used to knock via. + room_id: Room that we are trying to knock on. + user: User who is trying to knock. + content: A dict that should be used as the content of the knock event. + """ + raise NotImplementedError() + + @abc.abstractmethod async def remote_reject_invite( self, invite_event_id: str, @@ -145,6 +167,27 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise NotImplementedError() @abc.abstractmethod + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """Rescind a local knock made on a remote room. + + Args: + knock_event_id: The ID of the knock event to rescind. + txn_id: An optional transaction ID supplied by the client. + requester: The user making the request, according to the access token. + content: The content of the generated leave event. + + Returns: + A tuple containing (event_id, stream_id of the leave event). + """ + raise NotImplementedError() + + @abc.abstractmethod async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. @@ -191,7 +234,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # do it up front for efficiency.) 
if txn_id and requester.access_token_id: existing_event_id = await self.store.get_event_id_from_transaction_id( - room_id, requester.user.to_string(), requester.access_token_id, txn_id, + room_id, + requester.user.to_string(), + requester.access_token_id, + txn_id, ) if existing_event_id: event_pos = await self.store.get_position_for_event(existing_event_id) @@ -238,7 +284,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[target], ratelimit=ratelimit, + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, ) if event.membership == Membership.LEAVE: @@ -300,6 +350,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, ) -> Tuple[str, int]: """Update a user's membership in a room. @@ -340,6 +391,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed=third_party_signed, ratelimit=ratelimit, content=content, + new_room=new_room, require_consent=require_consent, ) @@ -356,6 +408,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): third_party_signed: Optional[dict] = None, ratelimit: bool = True, content: Optional[dict] = None, + new_room: bool = False, require_consent: bool = True, ) -> Tuple[str, int]: """Helper for update_membership. @@ -438,8 +491,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): ) block_invite = True + is_published = await self.store.is_room_published(room_id) + if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target_id, room_id + requester.user.to_string(), + target_id, + third_party_invite=None, + room_id=room_id, + new_room=new_room, + published_room=is_published, ): logger.info("Blocking invite due to spam checker") block_invite = True @@ -517,6 +577,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") + if ( + self._server_notices_mxid is not None + and requester.user.to_string() == self._server_notices_mxid + ): + # allow the server notices mxid to join rooms + is_requester_admin = True + + else: + is_requester_admin = await self.auth.is_server_admin(requester.user) + + inviter = await self._get_inviter(target.to_string(), room_id) + if not is_requester_admin: + # We assume that if the spam checker allowed the user to create + # a room then they're allowed to join it. 
+ if not new_room and not self.spam_checker.user_may_join_room( + target.to_string(), room_id, is_invited=inviter is not None + ): + raise SynapseError(403, "Not allowed to join this room") + if not is_host_in_room: if ratelimit: time_now_s = self.clock.time() @@ -554,50 +633,79 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: - # perhaps we've been invited + # Figure out the user's current membership state for the room ( current_membership_type, current_membership_event_id, ) = await self.store.get_local_current_membership_for_user_in_room( target.to_string(), room_id ) - if ( - current_membership_type != Membership.INVITE - or not current_membership_event_id - ): + if not current_membership_type or not current_membership_event_id: logger.info( "%s sent a leave request to %s, but that is not an active room " - "on this server, and there is no pending invite", + "on this server, or there is no pending invite or knock", target, room_id, ) raise SynapseError(404, "Not a known room") - invite = await self.store.get_event(current_membership_event_id) - logger.info( - "%s rejects invite to %s from %s", target, room_id, invite.sender - ) + # perhaps we've been invited + if current_membership_type == Membership.INVITE: + invite = await self.store.get_event(current_membership_event_id) + logger.info( + "%s rejects invite to %s from %s", + target, + room_id, + invite.sender, + ) - if not self.hs.is_mine_id(invite.sender): - # send the rejection to the inviter's HS (with fallback to - # local event) - return await self.remote_reject_invite( - invite.event_id, txn_id, requester, content, + if not self.hs.is_mine_id(invite.sender): + # send the rejection to the inviter's HS (with fallback to + # local event) + return await self.remote_reject_invite( + invite.event_id, + txn_id, + requester, + content, + ) + + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath, which will also send the + # rejection out to any other servers we believe are still in the room. + + # thanks to overzealous cleaning up of event_forward_extremities in + # `delete_old_current_state_events`, it's possible to end up with no + # forward extremities here. If that happens, let's just hang the + # rejection off the invite event. + # + # see: https://github.com/matrix-org/synapse/issues/7139 + if len(latest_event_ids) == 0: + latest_event_ids = [invite.event_id] + + # or perhaps this is a remote room that a local user has knocked on + elif current_membership_type == Membership.KNOCK: + knock = await self.store.get_event(current_membership_event_id) + return await self.remote_rescind_knock( + knock.event_id, txn_id, requester, content ) - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath, which will also send the - # rejection out to any other servers we believe are still in the room. + elif effective_membership_state == Membership.KNOCK: + if not is_host_in_room: + # The knock needs to be sent over federation instead + remote_room_hosts.append(get_domain_from_id(room_id)) + + content["membership"] = Membership.KNOCK - # thanks to overzealous cleaning up of event_forward_extremities in - # `delete_old_current_state_events`, it's possible to end up with no - # forward extremities here. If that happens, let's just hang the - # rejection off the invite event. 
- # - # see: https://github.com/matrix-org/synapse/issues/7139 - if len(latest_event_ids) == 0: - latest_event_ids = [invite.event_id] + profile = self.profile_handler + if "displayname" not in content: + content["displayname"] = await profile.get_displayname(target) + if "avatar_url" not in content: + content["avatar_url"] = await profile.get_avatar_url(target) + + return await self.remote_knock( + remote_room_hosts, room_id, target, content + ) return await self._local_membership_update( requester=requester, @@ -813,6 +921,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): id_server: str, requester: Requester, txn_id: Optional[str], + new_room: bool = False, id_access_token: Optional[str] = None, ) -> int: """Invite a 3PID to a room. @@ -860,6 +969,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): Codes.FORBIDDEN, ) + can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( + medium, address, room_id + ) + if not can_invite: + raise SynapseError( + 403, + "This third-party identifier can not be invited in this room", + Codes.FORBIDDEN, + ) + if not self._enable_lookup: raise SynapseError( 403, "Looking up third-party identifiers is denied from this server" @@ -869,6 +988,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): id_server, medium, address, id_access_token ) + is_published = await self.store.is_room_published(room_id) + + if not await self.spam_checker.user_may_invite( + requester.user.to_string(), + invitee, + third_party_invite={"medium": medium, "address": address}, + room_id=room_id, + new_room=new_room, + published_room=is_published, + ): + logger.info("Blocking invite due to spam checker") + raise SynapseError(403, "Invites have been disabled on this server") + if invitee: # Note that update_membership with an action of "invite" can raise # a ShadowBanError, but this was done above already. @@ -1056,8 +1188,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" # filter ourselves out of remote_room_hosts: do_invite_join ignores it # and if it is the only entry we'd like to return a 404 rather than a # 500. @@ -1158,6 +1289,35 @@ class RoomMemberMasterHandler(RoomMemberHandler): invite_event, txn_id, requester, content ) + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: + """ + Rescinds a local knock made on a remote room + + Args: + knock_event_id: The ID of the knock event to rescind. + txn_id: The transaction ID to use. + requester: The originator of the request. + content: The content of the leave event. + + Implements RoomMemberHandler.remote_rescind_knock + """ + # TODO: We don't yet support rescinding knocks over federation + # as we don't know which homeserver to send it to. An obvious + # candidate is the remote homeserver we originally knocked through, + # however we don't currently store that information. 
+ + # Just rescind the knock locally + knock_event = await self.store.get_event(knock_event_id) + return await self._generate_local_out_of_band_leave( + knock_event, txn_id, requester, content + ) + async def _generate_local_out_of_band_leave( self, previous_membership_event: EventBase, @@ -1211,16 +1371,44 @@ class RoomMemberMasterHandler(RoomMemberHandler): event.internal_metadata.out_of_band_membership = True result_event = await self.event_creation_handler.handle_new_client_event( - requester, event, context, extra_users=[UserID.from_string(target_user)], + requester, + event, + context, + extra_users=[UserID.from_string(target_user)], ) # we know it was persisted, so must have a stream ordering assert result_event.internal_metadata.stream_ordering return result_event.event_id, result_event.internal_metadata.stream_ordering - async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room + async def remote_knock( + self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict, + ) -> Tuple[str, int]: + """Sends a knock to a room. Attempts to do so via one remote out of a given list. + + Args: + remote_room_hosts: A list of homeservers to try knocking through. + room_id: The ID of the room to knock on. + user: The user to knock on behalf of. + content: The content of the knock event. + + Returns: + A tuple of (event ID, stream ID). """ + # filter ourselves out of remote_room_hosts + remote_room_hosts = [ + host for host in remote_room_hosts if host != self.hs.hostname + ] + + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + return await self.federation_handler.do_knock( + remote_room_hosts, room_id, user.to_string(), content=content + ) + + async def _user_left_room(self, target: UserID, room_id: str) -> None: + """Implements RoomMemberHandler._user_left_room""" user_left_room(self.distributor, target, room_id) async def forget(self, user: UserID, room_id: str) -> None: diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index f2e88f6a5b..926d09f40c 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py
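# The room_member.py hunks above route a local "leave" request by the user's
# current membership when this homeserver is no longer in the room: a pending
# invite is rejected remotely, a pending knock is rescinded, and anything else
# is a 404. A condensed sketch of that branching, using hypothetical stand-ins
# for the store and handler (this is not the real handler code):

async def handle_leave_when_not_in_room(handler, store, user_id, room_id):
    membership, event_id = await store.get_local_current_membership_for_user_in_room(
        user_id, room_id
    )
    if not membership or not event_id:
        raise RuntimeError("404: Not a known room")

    if membership == "invite":
        invite = await store.get_event(event_id)
        # Send the rejection to the inviter's homeserver, falling back to a
        # local out-of-band leave event if that fails.
        return await handler.remote_reject_invite(invite.event_id, None, user_id, {})

    if membership == "knock":
        knock = await store.get_event(event_id)
        # Rescinding currently only generates a local out-of-band leave, since
        # the homeserver the knock went through is not stored.
        return await handler.remote_rescind_knock(knock.event_id, None, user_id, {})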
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,10 +21,12 @@ from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler from synapse.replication.http.membership import ( ReplicationRemoteJoinRestServlet as ReplRemoteJoin, + ReplicationRemoteKnockRestServlet as ReplRemoteKnock, ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, + ReplicationRemoteRescindKnockRestServlet as ReplRescindKnock, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID logger = logging.getLogger(__name__) @@ -33,7 +36,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler): super().__init__(hs) self._remote_join_client = ReplRemoteJoin.make_client(hs) + self._remote_knock_client = ReplRemoteKnock.make_client(hs) self._remote_reject_client = ReplRejectInvite.make_client(hs) + self._remote_rescind_client = ReplRescindKnock.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) async def _remote_join( @@ -44,8 +49,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): user: UserID, content: dict, ) -> Tuple[str, int]: - """Implements RoomMemberHandler._remote_join - """ + """Implements RoomMemberHandler._remote_join""" if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") @@ -79,9 +83,51 @@ class RoomMemberWorkerHandler(RoomMemberHandler): ) return ret["event_id"], ret["stream_id"] - async def _user_left_room(self, target: UserID, room_id: str) -> None: - """Implements RoomMemberHandler._user_left_room + async def remote_rescind_knock( + self, + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ) -> Tuple[str, int]: """ + Rescinds a local knock made on a remote room + + Args: + knock_event_id: the knock event + txn_id: optional transaction ID supplied by the client + requester: user making the request, according to the access token + content: additional content to include in the leave event. + Normally an empty dict. + + Returns: + A tuple containing (event_id, stream_id of the leave event) + """ + ret = await self._remote_rescind_client( + knock_event_id=knock_event_id, + txn_id=txn_id, + requester=requester, + content=content, + ) + return ret["event_id"], ret["stream_id"] + + async def remote_knock( + self, remote_room_hosts: List[str], room_id: str, user: UserID, content: dict, + ) -> Tuple[str, int]: + """Sends a knock to a room. + + Implements RoomMemberHandler.remote_knock + """ + ret = await self._remote_knock_client( + remote_room_hosts=remote_room_hosts, + room_id=room_id, + user=user, + content=content, + ) + return ret["event_id"], ret["stream_id"] + + async def _user_left_room(self, target: UserID, room_id: str) -> None: + """Implements RoomMemberHandler._user_left_room""" await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index e88fd59749..a9645b77d8 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py
@@ -23,7 +23,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.config.saml2_config import SamlAttributeRequirement from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string @@ -122,7 +121,8 @@ class SamlHandler(BaseHandler): now = self.clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( - creation_time=now, ui_auth_session_id=ui_auth_session_id, + creation_time=now, + ui_auth_session_id=ui_auth_session_id, ) for key, value in info["headers"]: @@ -239,12 +239,10 @@ class SamlHandler(BaseHandler): # Ensure that the attributes of the logged in user meet the required # attributes. - for requirement in self._saml2_attribute_requirements: - if not _check_attribute_requirement(saml2_auth.ava, requirement): - self._sso_handler.render_error( - request, "unauthorised", "You are not authorised to log in here." - ) - return + if not self._sso_handler.check_required_attributes( + request, saml2_auth.ava, self._saml2_attribute_requirements + ): + return # Call the mapper to register/login the user try: @@ -373,21 +371,6 @@ class SamlHandler(BaseHandler): del self._outstanding_requests_dict[reqid] -def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool: - values = ava.get(req.attribute, []) - for v in values: - if v == req.value: - return True - - logger.info( - "SAML2 attribute %s did not match required value '%s' (was '%s')", - req.attribute, - req.value, - values, - ) - return False - - DOT_REPLACE_PATTERN = re.compile( ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) ) @@ -468,7 +451,8 @@ class DefaultSamlMappingProvider: mxid_source = saml_response.ava[self._mxid_source_attribute][0] except KeyError: logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + "SAML2 response lacks a '%s' attestation", + self._mxid_source_attribute, ) raise SynapseError( 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 84af2dde7e..cef6b3ae48 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2017 New Vector Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index b450668f1c..514b1f69d8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
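# The sso.py hunks below centralise attribute checking for all SSO flows: a
# requirement whose value is None is a pure presence check, otherwise the
# configured value must appear among the attribute's values. A minimal sketch
# of that rule, with a hypothetical Requirement type standing in for
# SsoAttributeRequirement:

from typing import Any, List, Mapping, NamedTuple, Optional

class Requirement(NamedTuple):
    attribute: str
    value: Optional[str]

def meets_requirement(attributes: Mapping[str, List[Any]], req: Requirement) -> bool:
    if req.attribute not in attributes:
        return False
    if req.value is None:
        # The attribute merely has to exist.
        return True
    return req.value in attributes[req.attribute]

# {"groups": ["staff", "admin"]} satisfies Requirement("groups", "staff") and
# Requirement("groups", None), but not Requirement("groups", "guest").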
@@ -16,10 +16,12 @@ import abc import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, Iterable, + List, Mapping, Optional, Set, @@ -34,6 +36,7 @@ from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError +from synapse.config.sso import SsoAttributeRequirement from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect @@ -324,7 +327,8 @@ class SsoHandler: # Check if we already have a mapping for this user. previously_registered_user_id = await self._store.get_user_by_external_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # A match was found, return the user ID. @@ -413,7 +417,8 @@ class SsoHandler: with await self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) # Check for grandfathering of users. @@ -458,7 +463,8 @@ class SsoHandler: ) async def _call_attribute_mapper( - self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], + self, + sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]], ) -> UserAttributes: """Call the attribute mapper function in a loop, until we get a unique userid""" for i in range(self._MAP_USERNAME_RETRIES): @@ -629,7 +635,8 @@ class SsoHandler: """ user_id = await self.get_sso_user_by_remote_user_id( - auth_provider_id, remote_user_id, + auth_provider_id, + remote_user_id, ) user_id_to_verify = await self._auth_handler.get_session_data( @@ -668,7 +675,8 @@ class SsoHandler: # render an error page. 
html = self._bad_user_template.render( - server_name=self._server_name, user_id_to_verify=user_id_to_verify, + server_name=self._server_name, + user_id_to_verify=user_id_to_verify, ) respond_with_html(request, 200, html) @@ -692,7 +700,9 @@ class SsoHandler: raise SynapseError(400, "unknown session") async def check_username_availability( - self, localpart: str, session_id: str, + self, + localpart: str, + session_id: str, ) -> bool: """Handle an "is username available" callback check @@ -742,7 +752,11 @@ class SsoHandler: use_display_name: whether the user wants to use the suggested display name emails_to_use: emails that the user would like to use """ - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return # update the session with the user's choices session.chosen_localpart = localpart @@ -793,7 +807,12 @@ class SsoHandler: session_id, terms_version, ) - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return + session.terms_accepted_version = terms_version # we're done; now we can register the user @@ -808,7 +827,11 @@ class SsoHandler: request: HTTP request session_id: ID of the username mapping session, extracted from a cookie """ - session = self.get_mapping_session(session_id) + try: + session = self.get_mapping_session(session_id) + except SynapseError as e: + self.render_error(request, "bad_session", e.msg, code=e.code) + return logger.info( "[session %s] Registering localpart %s", @@ -817,7 +840,8 @@ class SsoHandler: ) attributes = UserAttributes( - localpart=session.chosen_localpart, emails=session.emails_to_use, + localpart=session.chosen_localpart, + emails=session.emails_to_use, ) if session.use_display_name: @@ -880,6 +904,41 @@ class SsoHandler: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + def check_required_attributes( + self, + request: SynapseRequest, + attributes: Mapping[str, List[Any]], + attribute_requirements: Iterable[SsoAttributeRequirement], + ) -> bool: + """ + Confirm that the required attributes were present in the SSO response. + + If all requirements are met, this will return True. + + If any requirement is not met, then the request will be finalized by + showing an error page to the user and False will be returned. + + Args: + request: The request to (potentially) respond to. + attributes: The attributes from the SSO IdP. + attribute_requirements: The requirements that attributes must meet. + + Returns: + True if all requirements are met, False if any attribute fails to + meet the requirement. + + """ + # Ensure that the attributes of the logged in user meet the required + # attributes. + for requirement in attribute_requirements: + if not _check_attribute_requirement(attributes, requirement): + self.render_error( + request, "unauthorised", "You are not authorised to log in here." 
+                )
+                return False
+
+        return True
+

 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     """Extract the session ID from the cookie

@@ -890,3 +949,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     if not session_id:
         raise SynapseError(code=400, msg="missing session_id")
     return session_id.decode("ascii", errors="replace")
+
+
+def _check_attribute_requirement(
+    attributes: Mapping[str, List[Any]], req: SsoAttributeRequirement
+) -> bool:
+    """Check if SSO attributes meet the proper requirements.
+
+    Args:
+        attributes: A mapping of attributes to an iterable of one or more values.
+        req: The configured requirement to check.
+
+    Returns:
+        True if the required attribute was found and had a proper value.
+    """
+    if req.attribute not in attributes:
+        logger.info("SSO attribute missing: %s", req.attribute)
+        return False
+
+    # If the required value is None, the attribute existing is enough.
+    if req.value is None:
+        return True
+
+    values = attributes[req.attribute]
+    if req.value in values:
+        return True
+
+    logger.info(
+        "SSO attribute %s did not match required value '%s' (was '%s')",
+        req.attribute,
+        req.value,
+        values,
+    )
+    return False
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index d261d7cd4e..388dec5831 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py
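# The stats.py hunks below add a "knocked_members" bucket to the per-room
# membership counters. A toy sketch of the delta bookkeeping for a single
# membership transition (the Counter stands in for the real room_stats_delta
# dict; the real handler also rejects unknown membership values):

from collections import Counter

_FIELD = {
    "invite": "invited_members",
    "join": "joined_members",
    "leave": "left_members",
    "ban": "banned_members",
    "knock": "knocked_members",
}

def apply_membership_delta(delta: Counter, prev: str, new: str) -> None:
    # One user moved from `prev` to `new`: decrement the old bucket and
    # increment the new one.
    if prev in _FIELD:
        delta[_FIELD[prev]] -= 1
    if new in _FIELD:
        delta[_FIELD[new]] += 1

delta = Counter()
apply_membership_delta(delta, "knock", "join")
assert delta == Counter(joined_members=1, knocked_members=-1)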
@@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2018 New Vector Ltd +# Copyright 2020 Sorunome +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -63,8 +65,7 @@ class StatsHandler: self.clock.call_later(0, self.notify_new_event) def notify_new_event(self) -> None: - """Called when there may be more deltas to process - """ + """Called when there may be more deltas to process""" if not self.stats_enabled or self._is_processing: return @@ -232,6 +233,8 @@ class StatsHandler: room_stats_delta["left_members"] -= 1 elif prev_membership == Membership.BAN: room_stats_delta["banned_members"] -= 1 + elif prev_membership == Membership.KNOCK: + room_stats_delta["knocked_members"] -= 1 else: raise ValueError( "%r is not a valid prev_membership" % (prev_membership,) @@ -253,6 +256,8 @@ class StatsHandler: room_stats_delta["left_members"] += 1 elif membership == Membership.BAN: room_stats_delta["banned_members"] += 1 + elif membership == Membership.KNOCK: + room_stats_delta["knocked_members"] += 1 else: raise ValueError("%r is not a valid membership" % (membership,)) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5c7590f38e..fa6794734b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
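# The sync.py hunks below thread a third "knocked" bucket through the sync
# response alongside "joined" and "invited". A compressed sketch of how the
# most recent non-join membership event selects the bucket (hypothetical Event
# type; the real code also handles leave/ban archiving and ignored senders):

from dataclasses import dataclass
from typing import List

@dataclass
class Event:
    room_id: str
    membership: str

def bucket_membership(non_joins: List[Event], invited: list, knocked: list) -> None:
    # Only the latest membership matters: a knock that has since been
    # accepted or rejected should not surface as a knock.
    last = non_joins[-1]
    if last.membership == "invite":
        invited.append(last)
    elif last.membership == "knock":
        knocked.append(last)

invited, knocked = [], []
bucket_membership([Event("!room:example.org", "knock")], invited, knocked)
assert [e.room_id for e in knocked] == ["!room:example.org"]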
@@ -151,6 +151,16 @@ class InvitedSyncResult: @attr.s(slots=True, frozen=True) +class KnockedSyncResult: + room_id = attr.ib(type=str) + knock = attr.ib(type=EventBase) + + def __bool__(self) -> bool: + """Knocked rooms should always be reported to the client""" + return True + + +@attr.s(slots=True, frozen=True) class GroupsSyncResult: join = attr.ib(type=JsonDict) invite = attr.ib(type=JsonDict) @@ -183,6 +193,7 @@ class _RoomChanges: room_entries = attr.ib(type=List["RoomSyncResultBuilder"]) invited = attr.ib(type=List[InvitedSyncResult]) + knocked = attr.ib(type=List[KnockedSyncResult]) newly_joined_rooms = attr.ib(type=List[str]) newly_left_rooms = attr.ib(type=List[str]) @@ -196,6 +207,7 @@ class SyncResult: account_data: List of account_data events for the user. joined: JoinedSyncResult for each joined room. invited: InvitedSyncResult for each invited room. + knocked: KnockedSyncResult for each knocked on room. archived: ArchivedSyncResult for each archived room. to_device: List of direct messages for the device. device_lists: List of user_ids whose devices have changed @@ -211,6 +223,7 @@ class SyncResult: account_data = attr.ib(type=List[JsonDict]) joined = attr.ib(type=List[JoinedSyncResult]) invited = attr.ib(type=List[InvitedSyncResult]) + knocked = attr.ib(type=List[KnockedSyncResult]) archived = attr.ib(type=List[ArchivedSyncResult]) to_device = attr.ib(type=List[JsonDict]) device_lists = attr.ib(type=DeviceLists) @@ -227,6 +240,7 @@ class SyncResult: self.presence or self.joined or self.invited + or self.knocked or self.archived or self.account_data or self.to_device @@ -339,8 +353,7 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Get the sync for client needed to match what the server has now. - """ + """Get the sync for client needed to match what the server has now.""" return await self.generate_sync_result(sync_config, since_token, full_state) async def push_rules_for_user(self, user: UserID) -> JsonDict: @@ -564,7 +577,7 @@ class SyncHandler: stream_position: StreamToken, state_filter: StateFilter = StateFilter.all(), ) -> StateMap[str]: - """ Get the room state at a particular stream position + """Get the room state at a particular stream position Args: room_id: room for which to get state @@ -598,7 +611,7 @@ class SyncHandler: state: MutableStateMap[EventBase], now_token: StreamToken, ) -> Optional[JsonDict]: - """ Works out a room summary block for this room, summarising the number + """Works out a room summary block for this room, summarising the number of joined members in the room, and providing the 'hero' members if the room has no name so clients can consistently name rooms. Also adds state events to 'state' if needed to describe the heroes. @@ -743,7 +756,7 @@ class SyncHandler: now_token: StreamToken, full_state: bool, ) -> MutableStateMap[EventBase]: - """ Works out the difference in state between the start of the timeline + """Works out the difference in state between the start of the timeline and the previous sync. 
Args: @@ -820,8 +833,10 @@ class SyncHandler: ) elif batch.limited: if batch: - state_at_timeline_start = await self.state_store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + state_at_timeline_start = ( + await self.state_store.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) ) else: # We can get here if the user has ignored the senders of all @@ -955,8 +970,7 @@ class SyncHandler: since_token: Optional[StreamToken] = None, full_state: bool = False, ) -> SyncResult: - """Generates a sync result. - """ + """Generates a sync result.""" # NB: The now_token gets changed by some of the generate_sync_* methods, # this is due to some of the underlying streams not supporting the ability # to query up to a given point. @@ -999,7 +1013,7 @@ class SyncHandler: res = await self._generate_sync_entry_for_rooms( sync_result_builder, account_data_by_room ) - newly_joined_rooms, newly_joined_or_invited_users, _, _ = res + newly_joined_rooms, newly_joined_or_invited_or_knocked_users, _, _ = res _, _, newly_left_rooms, newly_left_users = res block_all_presence_data = ( @@ -1008,7 +1022,9 @@ class SyncHandler: if self.hs_config.use_presence and not block_all_presence_data: logger.debug("Fetching presence data") await self._generate_sync_entry_for_presence( - sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users + sync_result_builder, + newly_joined_rooms, + newly_joined_or_invited_or_knocked_users, ) logger.debug("Fetching to-device data") @@ -1017,7 +1033,7 @@ class SyncHandler: device_lists = await self._generate_sync_entry_for_device_list( sync_result_builder, newly_joined_rooms=newly_joined_rooms, - newly_joined_or_invited_users=newly_joined_or_invited_users, + newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users, newly_left_rooms=newly_left_rooms, newly_left_users=newly_left_users, ) @@ -1030,8 +1046,8 @@ class SyncHandler: one_time_key_counts = await self.store.count_e2e_one_time_keys( user_id, device_id ) - unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types( - user_id, device_id + unused_fallback_key_types = ( + await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) logger.debug("Fetching group data") @@ -1051,6 +1067,7 @@ class SyncHandler: account_data=sync_result_builder.account_data, joined=sync_result_builder.joined, invited=sync_result_builder.invited, + knocked=sync_result_builder.knocked, archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, @@ -1110,7 +1127,7 @@ class SyncHandler: self, sync_result_builder: "SyncResultBuilder", newly_joined_rooms: Set[str], - newly_joined_or_invited_users: Set[str], + newly_joined_or_invited_or_knocked_users: Set[str], newly_left_rooms: Set[str], newly_left_users: Set[str], ) -> DeviceLists: @@ -1119,8 +1136,9 @@ class SyncHandler: Args: sync_result_builder newly_joined_rooms: Set of rooms user has joined since previous sync - newly_joined_or_invited_users: Set of users that have joined or - been invited to a room since previous sync. + newly_joined_or_invited_or_knocked_users: Set of users that have joined, + been invited to a room or are knocking on a room since + previous sync. 
newly_left_rooms: Set of rooms user has left since previous sync newly_left_users: Set of users that have left a room we're in since previous sync @@ -1131,7 +1149,9 @@ class SyncHandler: # We're going to mutate these fields, so lets copy them rather than # assume they won't get used later. - newly_joined_or_invited_users = set(newly_joined_or_invited_users) + newly_joined_or_invited_or_knocked_users = set( + newly_joined_or_invited_or_knocked_users + ) newly_left_users = set(newly_left_users) if since_token and since_token.device_list_key: @@ -1170,14 +1190,16 @@ class SyncHandler: # Step 1b, check for newly joined rooms for room_id in newly_joined_rooms: joined_users = await self.state.get_current_users_in_room(room_id) - newly_joined_or_invited_users.update(joined_users) + newly_joined_or_invited_or_knocked_users.update(joined_users) # TODO: Check that these users are actually new, i.e. either they # weren't in the previous sync *or* they left and rejoined. - users_that_have_changed.update(newly_joined_or_invited_users) + users_that_have_changed.update(newly_joined_or_invited_or_knocked_users) - user_signatures_changed = await self.store.get_users_whose_signatures_changed( - user_id, since_token.device_list_key + user_signatures_changed = ( + await self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) ) users_that_have_changed.update(user_signatures_changed) @@ -1393,8 +1415,10 @@ class SyncHandler: logger.debug("no-oping sync") return set(), set(), set(), set() - ignored_account_data = await self.store.get_global_account_data_by_type_for_user( - AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ignored_account_data = ( + await self.store.get_global_account_data_by_type_for_user( + AccountDataTypes.IGNORED_USER_LIST, user_id=user_id + ) ) # If there is ignored users account data and it matches the proper type, @@ -1419,6 +1443,7 @@ class SyncHandler: room_entries = room_changes.room_entries invited = room_changes.invited + knocked = room_changes.knocked newly_joined_rooms = room_changes.newly_joined_rooms newly_left_rooms = room_changes.newly_left_rooms @@ -1439,9 +1464,10 @@ class SyncHandler: await concurrently_execute(handle_room_entries, room_entries, 10) sync_result_builder.invited.extend(invited) + sync_result_builder.knocked.extend(knocked) - # Now we want to get any newly joined or invited users - newly_joined_or_invited_users = set() + # Now we want to get any newly joined, invited or knocking users + newly_joined_or_invited_or_knocked_users = set() newly_left_users = set() if since_token: for joined_sync in sync_result_builder.joined: @@ -1453,19 +1479,22 @@ class SyncHandler: if ( event.membership == Membership.JOIN or event.membership == Membership.INVITE + or event.membership == Membership.KNOCK ): - newly_joined_or_invited_users.add(event.state_key) + newly_joined_or_invited_or_knocked_users.add( + event.state_key + ) else: prev_content = event.unsigned.get("prev_content", {}) prev_membership = prev_content.get("membership", None) if prev_membership == Membership.JOIN: newly_left_users.add(event.state_key) - newly_left_users -= newly_joined_or_invited_users + newly_left_users -= newly_joined_or_invited_or_knocked_users return ( set(newly_joined_rooms), - newly_joined_or_invited_users, + newly_joined_or_invited_or_knocked_users, set(newly_left_rooms), newly_left_users, ) @@ -1499,8 +1528,7 @@ class SyncHandler: async def _get_rooms_changed( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] ) -> 
_RoomChanges: - """Gets the the changes that have happened since the last sync. - """ + """Gets the the changes that have happened since the last sync.""" user_id = sync_result_builder.sync_config.user.to_string() since_token = sync_result_builder.since_token now_token = sync_result_builder.now_token @@ -1521,6 +1549,7 @@ class SyncHandler: newly_left_rooms = [] room_entries = [] invited = [] + knocked = [] for room_id, events in mem_change_events_by_room_id.items(): logger.debug( "Membership changes in %s: [%s]", @@ -1600,9 +1629,17 @@ class SyncHandler: should_invite = non_joins[-1].membership == Membership.INVITE if should_invite: if event.sender not in ignored_users: - room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) - if room_sync: - invited.append(room_sync) + invite_room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if invite_room_sync: + invited.append(invite_room_sync) + + # Only bother if our latest membership in the room is knock (and we haven't + # been accepted/rejected in the meantime). + should_knock = non_joins[-1].membership == Membership.KNOCK + if should_knock: + knock_room_sync = KnockedSyncResult(room_id, knock=non_joins[-1]) + if knock_room_sync: + knocked.append(knock_room_sync) # Always include leave/ban events. Just take the last one. # TODO: How do we handle ban -> leave in same batch? @@ -1706,7 +1743,9 @@ class SyncHandler: ) room_entries.append(entry) - return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms) + return _RoomChanges( + room_entries, invited, knocked, newly_joined_rooms, newly_left_rooms, + ) async def _get_all_rooms( self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] @@ -1726,6 +1765,7 @@ class SyncHandler: membership_list = ( Membership.INVITE, + Membership.KNOCK, Membership.JOIN, Membership.LEAVE, Membership.BAN, @@ -1737,6 +1777,7 @@ class SyncHandler: room_entries = [] invited = [] + knocked = [] for event in room_list: if event.membership == Membership.JOIN: @@ -1756,8 +1797,11 @@ class SyncHandler: continue invite = await self.store.get_event(event.event_id) invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite)) + elif event.membership == Membership.KNOCK: + knock = await self.store.get_event(event.event_id) + knocked.append(KnockedSyncResult(room_id=event.room_id, knock=knock)) elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. + # Always send down rooms we were banned from or kicked from. 
if not sync_config.filter_collection.include_leave: if event.membership == Membership.LEAVE: if user_id == event.sender: @@ -1778,7 +1822,7 @@ class SyncHandler: ) ) - return _RoomChanges(room_entries, invited, [], []) + return _RoomChanges(room_entries, invited, knocked, [], []) async def _generate_room_entry( self, @@ -2067,6 +2111,7 @@ class SyncResultBuilder: account_data (list) joined (list[JoinedSyncResult]) invited (list[InvitedSyncResult]) + knocked (list[KnockedSyncResult]) archived (list[ArchivedSyncResult]) groups (GroupsSyncResult|None) to_device (list) @@ -2082,6 +2127,7 @@ class SyncResultBuilder: account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list)) joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list)) invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list)) + knocked = attr.ib(type=List[KnockedSyncResult], default=attr.Factory(list)) archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list)) groups = attr.ib(type=Optional[GroupsSyncResult], default=None) to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list)) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3f0dfc7a74..096d199f4c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -61,7 +61,8 @@ class FollowerTypingHandler: if hs.config.worker.writers.typing != hs.get_instance_name(): hs.get_federation_registry().register_instance_for_edu( - "m.typing", hs.config.worker.writers.typing, + "m.typing", + hs.config.worker.writers.typing, ) # map room IDs to serial numbers @@ -76,8 +77,7 @@ class FollowerTypingHandler: self.clock.looping_call(self._handle_timeouts, 5000) def _reset(self) -> None: - """Reset the typing handler's data caches. - """ + """Reset the typing handler's data caches.""" # map room IDs to serial numbers self._room_serials = {} # map room IDs to sets of users currently typing @@ -149,8 +149,7 @@ class FollowerTypingHandler: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] ) -> None: - """Should be called whenever we receive updates for typing stream. - """ + """Should be called whenever we receive updates for typing stream.""" if self._latest_room_serial > token: # The master has gone backwards. To prevent inconsistent data, just diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 8aedf5072e..1a8340000a 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py
@@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler): return results def notify_new_event(self) -> None: - """Called when there may be more deltas to process - """ + """Called when there may be more deltas to process""" if not self.update_user_directory: return @@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler): ) async def handle_user_deactivated(self, user_id: str) -> None: - """Called when a user ID is deactivated - """ + """Called when a user ID is deactivated""" # FIXME(#3714): We should probably do this in the same worker as all # the other changes. await self.store.remove_from_user_dir(user_id) @@ -145,6 +143,10 @@ class UserDirectoryHandler(StateDeltasHandler): if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() + # If still None then the initial background update hasn't happened yet. + if self.pos is None: + return None + # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "user_dir_delta"): @@ -172,8 +174,7 @@ class UserDirectoryHandler(StateDeltasHandler): await self.store.update_user_directory_stream_pos(max_pos) async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None: - """Called with the state deltas to process - """ + """Called with the state deltas to process""" for delta in deltas: typ = delta["type"] state_key = delta["state_key"] diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 4bc3cb53f0..c658862fe6 100644 --- a/synapse/http/__init__.py +++ b/synapse/http/__init__.py
@@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer): def get_request_user_agent(request: IRequest, default: str = "") -> str: - """Return the last User-Agent header, or the given default. - """ + """Return the last User-Agent header, or the given default.""" # There could be raw utf-8 bytes in the User-Agent header. # N.B. if you don't do this, the logger explodes cryptically diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37ccf5ab98..a910548f1e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
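# The client.py hunks below cap download sizes in two places: up front, by
# rejecting a response whose declared Content-Length exceeds max_size without
# streaming the body at all, and during streaming, by aborting the connection
# once the limit is crossed. A sketch of the up-front check (UNKNOWN_LENGTH is
# twisted's sentinel for "no declared length"; the response is assumed to be
# an IResponse):

from typing import Optional

from twisted.internet import defer
from twisted.web.iweb import UNKNOWN_LENGTH

class BodyExceededMaxSize(Exception):
    """The response body exceeded the caller's limit."""

def check_declared_length(response, max_size: Optional[int]) -> Optional[defer.Deferred]:
    # response.length is an int taken from Content-Length, or UNKNOWN_LENGTH.
    if max_size is not None and response.length != UNKNOWN_LENGTH:
        if response.length > max_size:
            # Fail without ever calling response.deliverBody().
            return defer.fail(BodyExceededMaxSize())
    return None  # proceed to stream the body through a size-capped protocol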
@@ -56,7 +56,7 @@ from twisted.web.client import ( ) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IResponse +from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri @@ -289,8 +289,7 @@ class SimpleHttpClient: treq_args: Dict[str, Any] = {}, ip_whitelist: Optional[IPSet] = None, ip_blacklist: Optional[IPSet] = None, - http_proxy: Optional[bytes] = None, - https_proxy: Optional[bytes] = None, + use_proxy: bool = False, ): """ Args: @@ -300,8 +299,8 @@ class SimpleHttpClient: we may not request. ip_whitelist: The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. - http_proxy: proxy server to use for http connections. host[:port] - https_proxy: proxy server to use for https connections. host[:port] + use_proxy: Whether proxy settings should be discovered and used + from conventional environment variables. """ self.hs = hs @@ -345,8 +344,7 @@ class SimpleHttpClient: connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, - http_proxy=http_proxy, - https_proxy=https_proxy, + use_proxy=use_proxy, ) if self._ip_blacklist: @@ -398,7 +396,8 @@ class SimpleHttpClient: body_producer = None if data is not None: body_producer = QuieterFileBodyProducer( - BytesIO(data), cooperator=self._cooperator, + BytesIO(data), + cooperator=self._cooperator, ) request_deferred = treq.request( @@ -407,13 +406,18 @@ class SimpleHttpClient: agent=self.agent, data=body_producer, headers=headers, + # Avoid buffering the body in treq since we do not reuse + # response bodies. + unbuffered=True, **self._extra_treq_args, ) # type: defer.Deferred # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( - request_deferred, 60, self.hs.get_reactor(), + request_deferred, + 60, + self.hs.get_reactor(), ) # turn timeouts into RequestTimedOutErrors @@ -699,18 +703,6 @@ class SimpleHttpClient: resp_headers = dict(response.headers.getAllRawHeaders()) - if ( - b"Content-Length" in resp_headers - and max_size - and int(resp_headers[b"Content-Length"][0]) > max_size - ): - logger.warning("Requested URL is too large > %r bytes" % (max_size,)) - raise SynapseError( - 502, - "Requested file is too large > %r bytes" % (max_size,), - Codes.TOO_LARGE, - ) - if response.code > 299: logger.warning("Got %d when downloading %s" % (response.code, url)) raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) @@ -777,7 +769,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): # in the meantime. if self.max_size is not None and self.length >= self.max_size: self.deferred.errback(BodyExceededMaxSize()) - self.transport.loseConnection() + # Close the connection (forcefully) since all the data will get + # discarded anyway. + self.transport.abortConnection() def connectionLost(self, reason: Failure) -> None: # If the maximum size was already exceeded, there's nothing to do. @@ -811,6 +805,11 @@ def read_body_with_max_size( Returns: A Deferred which resolves to the length of the read body. """ + # If the Content-Length header gives a size larger than the maximum allowed + # size, do not bother downloading the body. 
+ if max_size is not None and response.length != UNKNOWN_LENGTH: + if response.length > max_size: + return defer.fail(BodyExceededMaxSize()) d = defer.Deferred() response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size)) diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 856e28454f..b797e3ce80 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py
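# The connectproxyclient.py hunks below allow extra headers to ride along on
# the CONNECT handshake, which is how Proxy-Authorization reaches an
# authenticating HTTPS proxy. A sketch of the request the setup client ends up
# writing, assembled by hand here (the real code goes through twisted's
# HTTPClient.sendCommand/sendHeader/endHeaders):

from twisted.web.http_headers import Headers

def build_connect_request(host: bytes, port: int, headers: Headers) -> bytes:
    lines = [b"CONNECT %s:%d HTTP/1.0" % (host, port)]
    for name, values in headers.getAllRawHeaders():
        for value in values:
            lines.append(name + b": " + value)
    lines.append(b"")  # blank line terminates the header block
    lines.append(b"")
    return b"\r\n".join(lines)

hdrs = Headers()
hdrs.addRawHeader(b"Proxy-Authorization", b"Basic dXNlcjpwYXNz")
print(build_connect_request(b"matrix.org", 443, hdrs).decode())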
@@ -19,9 +19,10 @@ from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import ConnectError -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.internet.protocol import connectionDone +from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.protocol import ClientFactory, Protocol, connectionDone from twisted.web import http +from twisted.web.http_headers import Headers logger = logging.getLogger(__name__) @@ -43,23 +44,33 @@ class HTTPConnectProxyEndpoint: Args: reactor: the Twisted reactor to use for the connection - proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the - proxy - host (bytes): hostname that we want to CONNECT to - port (int): port that we want to connect to + proxy_endpoint: the endpoint to use to connect to the proxy + host: hostname that we want to CONNECT to + port: port that we want to connect to + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, reactor, proxy_endpoint, host, port): + def __init__( + self, + reactor: IReactorCore, + proxy_endpoint: IStreamClientEndpoint, + host: bytes, + port: int, + headers: Headers, + ): self._reactor = reactor self._proxy_endpoint = proxy_endpoint self._host = host self._port = port + self._headers = headers def __repr__(self): return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,) - def connect(self, protocolFactory): - f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory) + def connect(self, protocolFactory: ClientFactory): + f = HTTPProxiedClientFactory( + self._host, self._port, protocolFactory, self._headers + ) d = self._proxy_endpoint.connect(f) # once the tcp socket connects successfully, we need to wait for the # CONNECT to complete. @@ -74,15 +85,23 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): HTTP Protocol object and run the rest of the connection. 
Args: - dst_host (bytes): hostname that we want to CONNECT to - dst_port (int): port that we want to connect to - wrapped_factory (protocol.ClientFactory): The original Factory + dst_host: hostname that we want to CONNECT to + dst_port: port that we want to connect to + wrapped_factory: The original Factory + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, dst_host, dst_port, wrapped_factory): + def __init__( + self, + dst_host: bytes, + dst_port: int, + wrapped_factory: ClientFactory, + headers: Headers, + ): self.dst_host = dst_host self.dst_port = dst_port self.wrapped_factory = wrapped_factory + self.headers = headers self.on_connection = defer.Deferred() def startedConnecting(self, connector): @@ -92,7 +111,11 @@ class HTTPProxiedClientFactory(protocol.ClientFactory): wrapped_protocol = self.wrapped_factory.buildProtocol(addr) return HTTPConnectProtocol( - self.dst_host, self.dst_port, wrapped_protocol, self.on_connection + self.dst_host, + self.dst_port, + wrapped_protocol, + self.on_connection, + self.headers, ) def clientConnectionFailed(self, connector, reason): @@ -112,24 +135,37 @@ class HTTPConnectProtocol(protocol.Protocol): """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect Args: - host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal + host: The original HTTP(s) hostname or IPv4 or IPv6 address literal to put in the CONNECT request - port (int): The original HTTP(s) port to put in the CONNECT request + port: The original HTTP(s) port to put in the CONNECT request - wrapped_protocol (interfaces.IProtocol): the original protocol (probably - HTTPChannel or TLSMemoryBIOProtocol, but could be anything really) + wrapped_protocol: the original protocol (probably HTTPChannel or + TLSMemoryBIOProtocol, but could be anything really) - connected_deferred (Deferred): a Deferred which will be callbacked with + connected_deferred: a Deferred which will be callbacked with wrapped_protocol when the CONNECT completes + + headers: Extra HTTP headers to include in the CONNECT request """ - def __init__(self, host, port, wrapped_protocol, connected_deferred): + def __init__( + self, + host: bytes, + port: int, + wrapped_protocol: Protocol, + connected_deferred: defer.Deferred, + headers: Headers, + ): self.host = host self.port = port self.wrapped_protocol = wrapped_protocol self.connected_deferred = connected_deferred - self.http_setup_client = HTTPConnectSetupClient(self.host, self.port) + self.headers = headers + + self.http_setup_client = HTTPConnectSetupClient( + self.host, self.port, self.headers + ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) def connectionMade(self): @@ -154,7 +190,7 @@ class HTTPConnectProtocol(protocol.Protocol): if buf: self.wrapped_protocol.dataReceived(buf) - def dataReceived(self, data): + def dataReceived(self, data: bytes): # if we've set up the HTTP protocol, we can send the data there if self.wrapped_protocol.connected: return self.wrapped_protocol.dataReceived(data) @@ -168,21 +204,29 @@ class HTTPConnectSetupClient(http.HTTPClient): """HTTPClient protocol to send a CONNECT message for proxies and read the response. 
Args: - host (bytes): The hostname to send in the CONNECT message - port (int): The port to send in the CONNECT message + host: The hostname to send in the CONNECT message + port: The port to send in the CONNECT message + headers: Extra headers to send with the CONNECT message """ - def __init__(self, host, port): + def __init__(self, host: bytes, port: int, headers: Headers): self.host = host self.port = port + self.headers = headers self.on_connected = defer.Deferred() def connectionMade(self): logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) + + # Send any additional specified headers + for name, values in self.headers.getAllRawHeaders(): + for value in values: + self.sendHeader(name, value) + self.endHeaders() - def handleStatus(self, version, status, message): + def handleStatus(self, version: bytes, status: bytes, message: bytes): logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": raise ProxyConnectError("Unexpected status on CONNECT: %s" % status) diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 4c06a117d3..2e83fa6773 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py
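# The matrix_federation_agent.py hunks below add debug logging around server
# discovery; the discovery order itself is unchanged: an explicit port or IP
# literal wins, then an SRV lookup for _matrix._tcp.<host>, then the hostname
# on the default federation port 8448. A sketch of that fallback, with the
# resolver passed in as an async stub:

from typing import List, NamedTuple, Optional

class Server(NamedTuple):
    host: bytes
    port: int

async def resolve_matrix_server(
    host: bytes, port: Optional[int], srv_resolver
) -> List[Server]:
    if port is not None:  # an IP literal takes the same early exit
        return [Server(host, port)]

    srv_records = await srv_resolver.resolve_service(b"_matrix._tcp." + host)
    if srv_records:
        return srv_records

    # No SRV records: fall back to the hostname on port 8448.
    return [Server(host, 8448)]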
@@ -195,8 +195,7 @@ class MatrixFederationAgent: @implementer(IAgentEndpointFactory) class MatrixHostnameEndpointFactory: - """Factory for MatrixHostnameEndpoint for parsing to an Agent. - """ + """Factory for MatrixHostnameEndpoint for parsing to an Agent.""" def __init__( self, @@ -261,8 +260,7 @@ class MatrixHostnameEndpoint: self._srv_resolver = srv_resolver def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: - """Implements IStreamClientEndpoint interface - """ + """Implements IStreamClientEndpoint interface""" return run_in_background(self._do_connect, protocol_factory) @@ -323,12 +321,19 @@ class MatrixHostnameEndpoint: if port or _is_ip_literal(host): return [Server(host, port or 8448)] + logger.debug("Looking up SRV record for %s", host.decode(errors="replace")) server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host) if server_list: + logger.debug( + "Got %s from SRV lookup for %s", + ", ".join(map(str, server_list)), + host.decode(errors="replace"), + ) return server_list # No SRV records, so we fallback to host and 8448 + logger.debug("No SRV records for %s", host.decode(errors="replace")) return [Server(host, 8448)] diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index b3b6dbcab0..4def7d7633 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py
@@ -81,8 +81,7 @@ class WellKnownLookupResult: class WellKnownResolver: - """Handles well-known lookups for matrix servers. - """ + """Handles well-known lookups for matrix servers.""" def __init__( self, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 19293bf673..cde42e9f5e 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -254,7 +254,8 @@ class MatrixFederationHttpClient: # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist, + self.agent, + ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() @@ -652,7 +653,7 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, ) -> Union[JsonDict, list]: - """ Sends the specified json data using PUT + """Sends the specified json data using PUT Args: destination: The remote server to send the HTTP request to. @@ -740,7 +741,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, args: Optional[QueryArgs] = None, ) -> Union[JsonDict, list]: - """ Sends the specified json data using POST + """Sends the specified json data using POST Args: destination: The remote server to send the HTTP request to. @@ -799,7 +800,11 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms, + self.reactor, + _sec_timeout, + request, + response, + start_ms, ) return body @@ -813,7 +818,7 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, ) -> Union[JsonDict, list]: - """ GETs some json from the given host homeserver and path + """GETs some json from the given host homeserver and path Args: destination: The remote server to send the HTTP request to. @@ -994,7 +999,10 @@ class MatrixFederationHttpClient: except BodyExceededMaxSize: msg = "Requested file is too large > %r bytes" % (max_size,) logger.warning( - "{%s} [%s] %s", request.txn_id, request.destination, msg, + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, ) raise SynapseError(502, msg, Codes.TOO_LARGE) except Exception as e: diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b730d2c634..ee65a6668b 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py
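# The proxyagent.py hunks below discover proxies from the conventional
# environment variables and peel optional credentials off the https proxy
# string. A sketch of the credential split and of the resulting
# Proxy-Authorization value (base64 of "user:pass" behind the Basic scheme;
# note that base64.encodebytes, as used in the diff, appends a newline):

import base64
from typing import Optional, Tuple

def split_proxy_credentials(
    proxy: Optional[bytes],
) -> Tuple[Optional[bytes], Optional[bytes]]:
    # rsplit on the last "@" so passwords containing "@" survive intact.
    if proxy and b"@" in proxy:
        creds, hostport = proxy.rsplit(b"@", 1)
        return creds, hostport
    return None, proxy

creds, hostport = split_proxy_credentials(b"user:pass@proxy.example.com:8888")
assert hostport == b"proxy.example.com:8888"
assert b"Basic " + base64.encodebytes(creds) == b"Basic dXNlcjpwYXNz\n"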
@@ -12,9 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import base64
 import logging
 import re
+from typing import Optional, Tuple
+from urllib.request import getproxies_environment, proxy_bypass_environment

+import attr
 from zope.interface import implementer

 from twisted.internet import defer
@@ -22,6 +26,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.python.failure import Failure
 from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
 from twisted.web.error import SchemeNotSupported
+from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent

 from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
@@ -31,6 +36,22 @@ logger = logging.getLogger(__name__)
 _VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")


+@attr.s
+class ProxyCredentials:
+    username_password = attr.ib(type=bytes)
+
+    def as_proxy_authorization_value(self) -> bytes:
+        """
+        Return the value for a Proxy-Authorization header (i.e. 'Basic abdef==').
+
+        Returns:
+            The base64-encoded authentication string, prefixed with the
+            authorization type, for use as a Proxy-Authorization header value.
+        """
+        # Encode as base64 and prepend the authorization type
+        return b"Basic " + base64.encodebytes(self.username_password)
+
+
 @implementer(IAgent)
 class ProxyAgent(_AgentBase):
     """An Agent implementation which will use an HTTP proxy if one was requested
@@ -58,6 +79,9 @@ class ProxyAgent(_AgentBase):

         pool (HTTPConnectionPool|None): connection pool to be used. If None, a
             non-persistent pool instance will be created.
+
+        use_proxy (bool): Whether proxy settings should be discovered and used
+            from conventional environment variables.
""" def __init__( @@ -68,8 +92,7 @@ class ProxyAgent(_AgentBase): connectTimeout=None, bindAddress=None, pool=None, - http_proxy=None, - https_proxy=None, + use_proxy=False, ): _AgentBase.__init__(self, reactor, pool) @@ -84,6 +107,18 @@ class ProxyAgent(_AgentBase): if bindAddress is not None: self._endpoint_kwargs["bindAddress"] = bindAddress + http_proxy = None + https_proxy = None + no_proxy = None + if use_proxy: + proxies = getproxies_environment() + http_proxy = proxies["http"].encode() if "http" in proxies else None + https_proxy = proxies["https"].encode() if "https" in proxies else None + no_proxy = proxies["no"] if "no" in proxies else None + + # Parse credentials from https proxy connection string if present + self.https_proxy_creds, https_proxy = parse_username_password(https_proxy) + self.http_proxy_endpoint = _http_proxy_endpoint( http_proxy, self.proxy_reactor, **self._endpoint_kwargs ) @@ -92,6 +127,8 @@ class ProxyAgent(_AgentBase): https_proxy, self.proxy_reactor, **self._endpoint_kwargs ) + self.no_proxy = no_proxy + self._policy_for_https = contextFactory self._reactor = reactor @@ -139,18 +176,43 @@ class ProxyAgent(_AgentBase): pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) request_path = parsed_uri.originForm - if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: + should_skip_proxy = False + if self.no_proxy is not None: + should_skip_proxy = proxy_bypass_environment( + parsed_uri.host.decode(), proxies={"no": self.no_proxy}, + ) + + if ( + parsed_uri.scheme == b"http" + and self.http_proxy_endpoint + and not should_skip_proxy + ): # Cache *all* connections under the same key, since we are only # connecting to a single destination, the proxy: pool_key = ("http-proxy", self.http_proxy_endpoint) endpoint = self.http_proxy_endpoint request_path = uri - elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: + elif ( + parsed_uri.scheme == b"https" + and self.https_proxy_endpoint + and not should_skip_proxy + ): + connect_headers = Headers() + + # Determine whether we need to set Proxy-Authorization headers + if self.https_proxy_creds: + # Set a Proxy-Authorization header + connect_headers.addRawHeader( + b"Proxy-Authorization", + self.https_proxy_creds.as_proxy_authorization_value(), + ) + endpoint = HTTPConnectProxyEndpoint( self.proxy_reactor, self.https_proxy_endpoint, parsed_uri.host, parsed_uri.port, + headers=connect_headers, ) else: # not using a proxy @@ -179,12 +241,16 @@ class ProxyAgent(_AgentBase): ) -def _http_proxy_endpoint(proxy, reactor, **kwargs): +def _http_proxy_endpoint(proxy: Optional[bytes], reactor, **kwargs): """Parses an http proxy setting and returns an endpoint for the proxy Args: - proxy (bytes|None): the proxy setting + proxy: the proxy setting in the form: [<username>:<password>@]<host>[:<port>] + Note that compared to other apps, this function currently lacks support + for specifying a protocol schema (i.e. protocol://...). + reactor: reactor to be used to connect to the proxy + kwargs: other args to be passed to HostnameEndpoint Returns: @@ -194,16 +260,43 @@ def _http_proxy_endpoint(proxy, reactor, **kwargs): if proxy is None: return None - # currently we only support hostname:port. Some apps also support - # protocol://<host>[:port], which allows a way of requiring a TLS connection to the - # proxy. 
-
+    # Parse the connection string
     host, port = parse_host_port(proxy, default_port=1080)
     return HostnameEndpoint(reactor, host, port, **kwargs)


-def parse_host_port(hostport, default_port=None):
-    # could have sworn we had one of these somewhere else...
+def parse_username_password(proxy: bytes) -> Tuple[Optional[ProxyCredentials], bytes]:
+    """
+    Parses the username and password from a proxy declaration, e.g.
+    username:password@hostname:port.
+
+    Args:
+        proxy: The proxy connection string.
+
+    Returns:
+        An instance of ProxyCredentials and the proxy connection string with any
+        credentials stripped, i.e. u:p@host:port -> host:port. If no credentials
+        were found, the ProxyCredentials instance is None.
+    """
+    if proxy and b"@" in proxy:
+        # We use rsplit here as the password could contain an @ character
+        credentials, proxy_without_credentials = proxy.rsplit(b"@", 1)
+        return ProxyCredentials(credentials), proxy_without_credentials
+
+    return None, proxy
+
+
+def parse_host_port(hostport: bytes, default_port: Optional[int] = None) -> Tuple[bytes, int]:
+    """
+    Parse the hostname and port from a proxy connection byte string.
+
+    Args:
+        hostport: The proxy connection string. Must be in the form 'host[:port]'.
+        default_port: The default port to return if one is not found in `hostport`.
+
+    Returns:
+        A tuple containing the hostname and port. Uses `default_port` if one was
+        not found.
+    """
     if b":" in hostport:
         host, port = hostport.rsplit(b":", 1)
         try:
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 7c5defec82..0ec5d941b8 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py
@@ -213,8 +213,7 @@ class RequestMetrics: self.update_metrics() def update_metrics(self): - """Updates the in flight metrics with values from this request. - """ + """Updates the in flight metrics with values from this request.""" new_stats = self.start_context.get_resource_usage() diff = new_stats - self._request_stats diff --git a/synapse/http/server.py b/synapse/http/server.py
index 8249732b27..845db9b78d 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -76,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: - """Sends a JSON error response to clients. - """ + """Sends a JSON error response to clients.""" if f.check(SynapseError): error_code = f.value.code @@ -106,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: pass else: respond_with_json( - request, error_code, error_dict, send_cors=True, + request, + error_code, + error_dict, + send_cors=True, ) def return_html_error( - f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template], + f: failure.Failure, + request: Request, + error_template: Union[str, jinja2.Template], ) -> None: """Sends an HTML error page corresponding to the given failure. @@ -189,8 +193,7 @@ ServletCallback = Callable[ class HttpServer(Protocol): - """ Interface for registering callbacks on a HTTP server - """ + """Interface for registering callbacks on a HTTP server""" def register_paths( self, @@ -199,7 +202,7 @@ class HttpServer(Protocol): callback: ServletCallback, servlet_classname: str, ) -> None: - """ Register a callback that gets fired if we receive a http request + """Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. If the regex contains groups these gets passed to the callback via @@ -235,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): self._extract_context = extract_context def render(self, request): - """ This gets called by twisted every time someone sends us a request. - """ + """This gets called by twisted every time someone sends us a request.""" defer.ensureDeferred(self._async_render_wrapper(request)) return NOT_DONE_YET @@ -287,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): @abc.abstractmethod def _send_response( - self, request: SynapseRequest, code: int, response_object: Any, + self, + request: SynapseRequest, + code: int, + response_object: Any, ) -> None: raise NotImplementedError() @abc.abstractmethod def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: raise NotImplementedError() @@ -308,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource): self.canonical_json = canonical_json def _send_response( - self, request: Request, code: int, response_object: Any, + self, + request: Request, + code: int, + response_object: Any, ): - """Implements _AsyncResource._send_response - """ + """Implements _AsyncResource._send_response""" # TODO: Only enable CORS for the requests that need it. respond_with_json( request, @@ -322,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource): ) def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: - """Implements _AsyncResource._send_error_response - """ + """Implements _AsyncResource._send_error_response""" return_json_error(f, request) class JsonResource(DirectServeJsonResource): - """ This implements the HttpServer interface and provides JSON support for + """This implements the HttpServer interface and provides JSON support for Resources. 
Register callbacks via register_paths() @@ -443,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource): ERROR_TEMPLATE = HTML_ERROR_TEMPLATE def _send_response( - self, request: SynapseRequest, code: int, response_object: Any, + self, + request: SynapseRequest, + code: int, + response_object: Any, ): - """Implements _AsyncResource._send_response - """ + """Implements _AsyncResource._send_response""" # We expect to get bytes for us to write assert isinstance(response_object, bytes) html_bytes = response_object @@ -454,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource): respond_with_html_bytes(request, 200, html_bytes) def _send_error_response( - self, f: failure.Failure, request: SynapseRequest, + self, + f: failure.Failure, + request: SynapseRequest, ) -> None: - """Implements _AsyncResource._send_error_response - """ + """Implements _AsyncResource._send_error_response""" return_html_error(f, request, self.ERROR_TEMPLATE) @@ -534,7 +547,9 @@ class _ByteProducer: min_chunk_size = 1024 def __init__( - self, request: Request, iterator: Iterator[bytes], + self, + request: Request, + iterator: Iterator[bytes], ): self._request = request self._iterator = iterator @@ -654,7 +669,10 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, code: int, json_bytes: bytes, send_cors: bool = False, + request: Request, + code: int, + json_bytes: bytes, + send_cors: bool = False, ): """Sends encoded JSON in response to the given request. @@ -769,7 +787,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None: def finish_request(request: Request): - """ Finish writing the response to the request. + """Finish writing the response to the request. Twisted throws a RuntimeException if the connection closed before the response was written but doesn't provide a convenient or reliable way to diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
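The `HttpServer` protocol formalised above is what `JsonResource` implements: callbacks are registered against a compiled regex per path, and named groups from the regex come back as keyword arguments. A minimal sketch of wiring one up; the path, callback and servlet name are invented for illustration:

import re

from synapse.http.server import JsonResource


async def _hello(request, user_id):
    # `user_id` is the named group captured from the path regex; returning
    # a (code, json_object) tuple lets JsonResource serialise the response.
    return 200, {"hello": user_id}


def build_example_resource(hs):
    resource = JsonResource(hs)
    resource.register_paths(
        "GET",
        [re.compile(r"^/_example/hello/(?P<user_id>[^/]+)$")],
        _hello,
        "ExampleServlet",
    )
    return resource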
index b361b7cbaf..839d58d0d4 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -14,8 +14,8 @@ # limitations under the License. """ This module contains base REST classes for constructing REST servlets. """ - import logging +from typing import Dict, List, Optional, Union from synapse.api.errors import Codes, SynapseError from synapse.util import json_decoder @@ -147,16 +147,67 @@ def parse_string( ) +def parse_list_from_args( + args: Dict[bytes, List[bytes]], + name: Union[bytes, str], + encoding: Optional[str] = "ascii", +): + """Parse and optionally decode a list of values from request query parameters. + + Args: + args: A dictionary of query parameters from a request. + name: The name of the query parameter to extract values from. If given as bytes, + will be decoded as "ascii". + encoding: An optional encoding that is used to decode each parameter value with. + + Raises: + KeyError: If the given `name` does not exist in `args`. + SynapseError: If an argument was not encoded with the specified `encoding`. + """ + if not isinstance(name, bytes): + name = name.encode("ascii") + args_list = args[name] + + if encoding: + # Decode each argument value + try: + args_list = [value.decode(encoding) for value in args_list] + except ValueError: + raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding)) + + return args_list + + def parse_string_from_args( - args, - name, - default=None, - required=False, - allowed_values=None, - param_type="string", - encoding="ascii", + args: Dict[bytes, List[bytes]], + name: Union[bytes, str], + default: Optional[str] = None, + required: Optional[bool] = False, + allowed_values: Optional[List[bytes]] = None, + param_type: Optional[str] = "string", + encoding: Optional[str] = "ascii", ): + """Parse and optionally decode a single value from request query parameters. + Args: + args: A dictionary of query parameters from a request. + name: The name of the query parameter to extract values from. If given as bytes, + will be decoded as "ascii". + default: A default value to return if the given argument `name` was not found. + required: If this is True, no `default` is provided and the given argument `name` + was not found then a SynapseError is raised. + allowed_values: A list of allowed values. If specified and the found str is + not in this list, a SynapseError is raised. + param_type: The expected type of the query parameter's value. + encoding: An optional encoding that is used to decode each parameter value with. + + Returns: + The found argument value. + + Raises: + SynapseError: If the given name was not found in the request arguments, + the argument's values were encoded incorrectly or a required value was missing. + """ if not isinstance(name, bytes): name = name.encode("ascii") @@ -258,7 +309,7 @@ def assert_params_in_dict(body, required): class RestServlet: - """ A Synapse REST Servlet. + """A Synapse REST Servlet. An implementing class can either provide its own custom 'register' method, or use the automatic pattern handling provided by the base class. diff --git a/synapse/http/site.py b/synapse/http/site.py
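`parse_list_from_args` is the multi-valued counterpart to `parse_string_from_args`: it returns every value supplied for a query parameter, decoded with the given encoding, and (unlike the other helpers) raises `KeyError` when the parameter is absent. A quick sketch against a hand-built `args` dict of the shape Twisted exposes as `request.args`:

from synapse.http.servlet import parse_list_from_args

args = {b"via": [b"example.org", b"matrix.org"]}

print(parse_list_from_args(args, "via"))  # ['example.org', 'matrix.org']

try:
    parse_list_from_args(args, "missing")
except KeyError:
    print("parameter not supplied")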
index 12ec3f851f..4a4fb5ef26 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -249,8 +249,7 @@ class SynapseRequest(Request): ) def _finished_processing(self): - """Log the completion of this request and update the metrics - """ + """Log the completion of this request and update the metrics""" assert self.logcontext is not None usage = self.logcontext.get_resource_usage() @@ -276,7 +275,8 @@ class SynapseRequest(Request): # authenticated (e.g. an admin is puppetting a user) then we log both. if self.requester.user.to_string() != authenticated_entity: authenticated_entity = "{},{}".format( - authenticated_entity, self.requester.user.to_string(), + authenticated_entity, + self.requester.user.to_string(), ) elif self.requester is not None: # This shouldn't happen, but we log it so we don't lose information @@ -322,8 +322,7 @@ class SynapseRequest(Request): logger.warning("Failed to stop metrics: %r", e) def _should_log_request(self) -> bool: - """Whether we should log at INFO that we processed the request. - """ + """Whether we should log at INFO that we processed the request.""" if self.path == b"/health": return False diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index fb937b3f28..f8e9112b56 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py
@@ -174,7 +174,9 @@ class RemoteHandler(logging.Handler): # Make a new producer and start it. self._producer = LogProducer( - buffer=self._buffer, transport=result.transport, format=self.format, + buffer=self._buffer, + transport=result.transport, + format=self.format, ) result.transport.registerProducer(self._producer, True) self._producer.resumeProducing() diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 14d9c104c2..3e054f615c 100644 --- a/synapse/logging/_structured.py +++ b/synapse/logging/_structured.py
@@ -60,7 +60,10 @@ def parse_drain_configs( ) # Either use the default formatter or the tersejson one. - if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,): + if logging_type in ( + DrainType.CONSOLE_JSON, + DrainType.FILE_JSON, + ): formatter = "json" # type: Optional[str] elif logging_type in ( DrainType.CONSOLE_JSON_TERSE, @@ -131,7 +134,9 @@ def parse_drain_configs( ) -def setup_structured_logging(log_config: dict,) -> dict: +def setup_structured_logging( + log_config: dict, +) -> dict: """ Convert a legacy structured logging configuration (from Synapse < v1.23.0) to one compatible with the new standard library handlers. diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index c2db8b45f3..78e27bfb00 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -338,7 +338,10 @@ class LoggingContext: if self.previous_context != old_context: logcontext_error( "Expected previous context %r, found %r" - % (self.previous_context, old_context,) + % ( + self.previous_context, + old_context, + ) ) return self @@ -562,7 +565,7 @@ class LoggingContextFilter(logging.Filter): class PreserveLoggingContext: """Context manager which replaces the logging context - The previous logging context is restored on exit.""" + The previous logging context is restored on exit.""" __slots__ = ["_old_context", "_new_context"] @@ -585,7 +588,10 @@ class PreserveLoggingContext: else: logcontext_error( "Expected logging context %s but found %s" - % (self._new_context, context,) + % ( + self._new_context, + context, + ) ) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 0538350f38..10bd4a1461 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py
@@ -238,8 +238,7 @@ try: @attr.s(slots=True, frozen=True) class _WrappedRustReporter: - """Wrap the reporter to ensure `report_span` never throws. - """ + """Wrap the reporter to ensure `report_span` never throws.""" _reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter)) @@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs): def init_tracer(hs: "HomeServer"): - """Set the whitelists and initialise the JaegerClient tracer - """ + """Set the whitelists and initialise the JaegerClient tracer""" global opentracing if not hs.config.opentracer_enabled: # We don't have a tracer @@ -384,7 +382,7 @@ def whitelisted_homeserver(destination): Args: destination (str) - """ + """ if _homeserver_whitelist: return _homeserver_whitelist.match(destination) diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index becf66dd86..fd3543ab04 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py
@@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args): def log_function(f): - """ Function decorator that logs every call to that function. - """ + """Function decorator that logs every call to that function.""" func_name = f.__name__ @wraps(f) diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index cbf0dbb871..a8cb49d5b4 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py
@@ -155,8 +155,7 @@ class InFlightGauge: self._registrations.setdefault(key, set()).add(callback) def unregister(self, key, callback): - """Registers that we've exited a block with labels `key`. - """ + """Registers that we've exited a block with labels `key`.""" with self._lock: self._registrations.setdefault(key, set()).discard(callback) @@ -402,7 +401,9 @@ class PyPyGCStats: # Total time spent in GC: 0.073 # s.total_gc_time pypy_gc_time = CounterMetricFamily( - "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[], + "pypy_gc_time_seconds_total", + "Total time spent in PyPy GC", + labels=[], ) pypy_gc_time.add_metric([], s.total_gc_time / 1000) yield pypy_gc_time diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
index 734271e765..71320a1402 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py
@@ -216,7 +216,7 @@ class MetricsHandler(BaseHTTPRequestHandler): @classmethod def factory(cls, registry): """Returns a dynamic MetricsHandler class tied - to the passed registry. + to the passed registry. """ # This implementation relies on MetricsHandler.registry # (defined above and defaulted to REGISTRY). diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 70e0fa45d9..b56986d8e7 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py
@@ -208,7 +208,8 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar return await maybe_awaitable(func(*args, **kwargs)) except Exception: logger.exception( - "Background process '%s' threw an exception", desc, + "Background process '%s' threw an exception", + desc, ) finally: _background_process_in_flight_count.labels(desc).dec() @@ -249,8 +250,7 @@ class BackgroundProcessLoggingContext(LoggingContext): self._proc = _BackgroundProcess(name, self) def start(self, rusage: "Optional[resource._RUsage]"): - """Log context has started running (again). - """ + """Log context has started running (again).""" super().start(rusage) @@ -261,8 +261,7 @@ class BackgroundProcessLoggingContext(LoggingContext): _background_processes_active_since_last_scrape.add(self._proc) def __exit__(self, type, value, traceback) -> None: - """Log context has finished. - """ + """Log context has finished.""" super().__exit__(type, value, traceback) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
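`run_as_background_process` is the standard wrapper for fire-and-forget work: it gives the coroutine its own logcontext, tracks in-flight counts per description, and (per the hunk above) logs any exception rather than letting it vanish. Typical usage, with an invented task name:

from synapse.metrics.background_process_metrics import run_as_background_process


async def _prune_old_entries():
    # ... periodic clean-up work would go here ...
    pass


# "prune_old_entries" is only a metrics label; any extra positional or
# keyword arguments are forwarded to the coroutine.
run_as_background_process("prune_old_entries", _prune_old_entries)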
index 401d577293..2e3b311c4a 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -275,7 +275,9 @@ class ModuleApi: redirect them directly if whitelisted). """ self._auth_handler._complete_sso_login( - registered_user_id, request, client_redirect_url, + registered_user_id, + request, + client_redirect_url, ) async def complete_sso_login_async( @@ -352,7 +354,10 @@ class ModuleApi: event, _, ) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event( - requester, event_dict, ratelimit=False, ignore_shadow_ban=True, + requester, + event_dict, + ratelimit=False, + ignore_shadow_ban=True, ) return event diff --git a/synapse/notifier.py b/synapse/notifier.py
index 0745899b48..1374aae490 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py
@@ -75,7 +75,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int: class _NotificationListener: - """ This represents a single client connection to the events stream. + """This represents a single client connection to the events stream. The events stream handler will have yielded to the deferred, so to notify the handler it is sufficient to resolve the deferred. """ @@ -119,7 +119,10 @@ class _NotifierUserStream: self.notify_deferred = ObservableDeferred(defer.Deferred()) def notify( - self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int, + self, + stream_key: str, + stream_id: Union[int, RoomStreamToken], + time_now_ms: int, ): """Notify any listeners for this user of a new event from an event source. @@ -140,7 +143,7 @@ class _NotifierUserStream: noify_deferred.callback(self.current_token) def remove(self, notifier: "Notifier"): - """ Remove this listener from all the indexes in the Notifier + """Remove this listener from all the indexes in the Notifier it knows about. """ @@ -186,7 +189,7 @@ class _PendingRoomEventEntry: class Notifier: - """ This class is responsible for notifying any listeners when there are + """This class is responsible for notifying any listeners when there are new events available for it. Primarily used from the /events stream. @@ -265,8 +268,7 @@ class Notifier: max_room_stream_token: RoomStreamToken, extra_users: Collection[UserID] = [], ): - """Unwraps event and calls `on_new_room_event_args`. - """ + """Unwraps event and calls `on_new_room_event_args`.""" self.on_new_room_event_args( event_pos=event_pos, room_id=event.room_id, @@ -341,7 +343,10 @@ class Notifier: if users or rooms: self.on_new_event( - "room_key", max_room_stream_token, users=users, rooms=rooms, + "room_key", + max_room_stream_token, + users=users, + rooms=rooms, ) self._on_updated_room_token(max_room_stream_token) @@ -392,7 +397,7 @@ class Notifier: users: Collection[Union[str, UserID]] = [], rooms: Collection[str] = [], ): - """ Used to inform listeners that something has happened event wise. + """Used to inform listeners that something has happened event wise. Will wake up all listeners for the given users and rooms. """ @@ -418,7 +423,9 @@ class Notifier: # Notify appservices self._notify_app_services_ephemeral( - stream_key, new_token, users, + stream_key, + new_token, + users, ) def on_new_replication_data(self) -> None: @@ -502,7 +509,7 @@ class Notifier: is_guest: bool = False, explicit_room_id: str = None, ) -> EventStreamResult: - """ For the given user and rooms, return any new events for them. If + """For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any new events to happen before returning. @@ -651,8 +658,7 @@ class Notifier: cb() def notify_remote_server_up(self, server: str): - """Notify any replication that a remote server has come back up - """ + """Notify any replication that a remote server has come back up""" # We call federation_sender directly rather than registering as a # callback as a) we already have a reference to it and b) it introduces # circular dependencies. diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6211506990..4d284de133 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py
@@ -495,7 +495,11 @@ BASE_APPEND_UNDERRIDE_RULES = [ "_id": "_message", } ], - "actions": ["notify", {"set_tweak": "highlight", "value": False}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, # XXX: this is going to fire for events which aren't m.room.messages # but are encrypted (e.g. m.call.*)... @@ -509,7 +513,11 @@ BASE_APPEND_UNDERRIDE_RULES = [ "_id": "_encrypted", } ], - "actions": ["notify", {"set_tweak": "highlight", "value": False}], + "actions": [ + "notify", + {"set_tweak": "sound", "value": "default"}, + {"set_tweak": "highlight", "value": False}, + ], }, { "rule_id": "global/underride/.im.vector.jitsi", diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
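The two underride rules regain the `sound` tweak, so plain and encrypted messages once again play the default notification sound rather than only setting `highlight: false`. Push gateways see the tweaks as a name-to-value map; a toy extraction of that map from the new action list (this helper is illustrative, not Synapse's own):

actions = [
    "notify",
    {"set_tweak": "sound", "value": "default"},
    {"set_tweak": "highlight", "value": False},
]


def tweaks_for_actions(actions):
    # Plain string actions such as "notify" carry no tweak.
    return {
        a["set_tweak"]: a.get("value", True)
        for a in actions
        if isinstance(a, dict) and "set_tweak" in a
    }


print(tweaks_for_actions(actions))  # {'sound': 'default', 'highlight': False}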
index 9018f9e20b..c016a83909 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -144,8 +144,7 @@ class BulkPushRuleEvaluator: @lru_cache() def _get_rules_for_room(self, room_id: str) -> "RulesForRoom": - """Get the current RulesForRoom object for the given room id - """ + """Get the current RulesForRoom object for the given room id""" # It's important that RulesForRoom gets added to self._get_rules_for_room.cache # before any lookup methods get called on it as otherwise there may be # a race if invalidate_all gets called (which assumes its in the cache) @@ -252,7 +251,9 @@ class BulkPushRuleEvaluator: # notified for this event. (This will then get handled when we persist # the event) await self.store.add_push_actions_to_staging( - event.event_id, actions_by_user, count_as_unread, + event.event_id, + actions_by_user, + count_as_unread, ) @@ -524,7 +525,7 @@ class RulesForRoom: class _Invalidation: # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules, # which means that it it is stored on the bulk_get_push_rules cache entry. In order - # to ensure that we don't accumulate lots of redunant callbacks on the cache entry, + # to ensure that we don't accumulate lots of redundant callbacks on the cache entry, # we need to ensure that two _Invalidation objects are "equal" if they refer to the # same `cache` and `room_id`. # diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 4ac1b31748..5fec2aaf5d 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py
@@ -116,8 +116,7 @@ class EmailPusher(Pusher): self._is_processing = True def _resume_processing(self) -> None: - """Used by tests to resume processing of events after pausing. - """ + """Used by tests to resume processing of events after pausing.""" assert self._is_processing self._is_processing = False self._start_processing() @@ -157,8 +156,10 @@ class EmailPusher(Pusher): being run. """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering - unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( - self.user_id, start, self.max_stream_ordering + unprocessed = ( + await self.store.get_unread_push_actions_for_user_in_range_for_email( + self.user_id, start, self.max_stream_ordering + ) ) soonest_due_at = None # type: Optional[int] @@ -222,12 +223,14 @@ class EmailPusher(Pusher): self, last_stream_ordering: int ) -> None: self.last_stream_ordering = last_stream_ordering - pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.email, - self.user_id, - last_stream_ordering, - self.clock.time_msec(), + pusher_still_exists = ( + await self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.email, + self.user_id, + last_stream_ordering, + self.clock.time_msec(), + ) ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -298,7 +301,8 @@ class EmailPusher(Pusher): current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS ) self.throttle_params[room_id] = ThrottleParams( - self.clock.time_msec(), new_throttle_ms, + self.clock.time_msec(), + new_throttle_ms, ) assert self.pusher_id is not None await self.store.set_throttle_params( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e048b0d59e..b9d3da2e0a 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py
@@ -176,8 +176,10 @@ class HttpPusher(Pusher): Never call this directly: use _process which will only allow this to run once per pusher. """ - unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( - self.user_id, self.last_stream_ordering, self.max_stream_ordering + unprocessed = ( + await self.store.get_unread_push_actions_for_user_in_range_for_http( + self.user_id, self.last_stream_ordering, self.max_stream_ordering + ) ) logger.info( @@ -204,12 +206,14 @@ class HttpPusher(Pusher): http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( - self.app_id, - self.pushkey, - self.user_id, - self.last_stream_ordering, - self.clock.time_msec(), + pusher_still_exists = ( + await self.store.update_pusher_last_stream_ordering_and_success( + self.app_id, + self.pushkey, + self.user_id, + self.last_stream_ordering, + self.clock.time_msec(), + ) ) if not pusher_still_exists: # The pusher has been deleted while we were processing, so @@ -290,7 +294,8 @@ class HttpPusher(Pusher): # for sanity, we only remove the pushkey if it # was the one we actually sent... logger.warning( - ("Ignoring rejected pushkey %s because we didn't send it"), pk, + ("Ignoring rejected pushkey %s because we didn't send it"), + pk, ) else: logger.info("Pushkey %s was rejected: removing", pk) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 8a6dcff30d..d10201b6b3 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py
@@ -34,6 +34,7 @@ from synapse.push.presentable_names import ( descriptor_from_member_events, name_from_member_event, ) +from synapse.storage.state import StateFilter from synapse.types import StateMap, UserID from synapse.util.async_helpers import concurrently_execute from synapse.visibility import filter_events_for_client @@ -110,6 +111,7 @@ class Mailer: self.sendmail = self.hs.get_sendmail() self.store = self.hs.get_datastore() + self.state_store = self.hs.get_storage().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() self.storage = hs.get_storage() @@ -217,7 +219,17 @@ class Mailer: push_actions: Iterable[Dict[str, Any]], reason: Dict[str, Any], ) -> None: - """Send email regarding a user's room notifications""" + """ + Send email regarding a user's room notifications + + Params: + app_id: The application receiving the notification. + user_id: The user receiving the notification. + email_address: The email address receiving the notification. + push_actions: All outstanding notifications. + reason: The notification that was ready and is the cause of an email + being sent. + """ rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) notif_events = await self.store.get_events( @@ -241,7 +253,7 @@ class Mailer: except StoreError: user_display_name = user_id - async def _fetch_room_state(room_id): + async def _fetch_room_state(room_id: str) -> None: room_state = await self.store.get_current_state_ids(room_id) state_by_room[room_id] = room_state @@ -255,7 +267,7 @@ class Mailer: rooms = [] for r in rooms_in_order: - roomvars = await self.get_room_vars( + roomvars = await self._get_room_vars( r, user_id, notifs_by_room[r], notif_events, state_by_room[r] ) rooms.append(roomvars) @@ -271,7 +283,7 @@ class Mailer: # Only one room has new stuff room_id = list(notifs_by_room.keys())[0] - summary_text = await self.make_summary_text_single_room( + summary_text = await self._make_summary_text_single_room( room_id, notifs_by_room[room_id], state_by_room[room_id], @@ -279,13 +291,13 @@ class Mailer: user_id, ) else: - summary_text = await self.make_summary_text( + summary_text = await self._make_summary_text( notifs_by_room, state_by_room, notif_events, reason ) template_vars = { "user_display_name": user_display_name, - "unsubscribe_link": self.make_unsubscribe_link( + "unsubscribe_link": self._make_unsubscribe_link( user_id, app_id, email_address ), "summary_text": summary_text, @@ -349,7 +361,7 @@ class Mailer: ) ) - async def get_room_vars( + async def _get_room_vars( self, room_id: str, user_id: str, @@ -357,6 +369,20 @@ class Mailer: notif_events: Dict[str, EventBase], room_state_ids: StateMap[str], ) -> Dict[str, Any]: + """ + Generate the variables for notifications on a per-room basis. + + Args: + room_id: The room ID + user_id: The user receiving the notification. + notifs: The outstanding push actions for this room. + notif_events: The events related to the above notifications. + room_state_ids: The event IDs of the current room state. + + Returns: + A dictionary to be added to the template context. + """ + # Check if one of the notifs is an invite event for the user. 
is_invite = False for n in notifs: @@ -373,12 +399,12 @@ class Mailer: "hash": string_ordinal_total(room_id), # See sender avatar hash "notifs": [], "invite": is_invite, - "link": self.make_room_link(room_id), + "link": self._make_room_link(room_id), } # type: Dict[str, Any] if not is_invite: for n in notifs: - notifvars = await self.get_notif_vars( + notifvars = await self._get_notif_vars( n, user_id, notif_events[n["event_id"]], room_state_ids ) @@ -405,13 +431,26 @@ class Mailer: return room_vars - async def get_notif_vars( + async def _get_notif_vars( self, notif: Dict[str, Any], user_id: str, notif_event: EventBase, room_state_ids: StateMap[str], ) -> Dict[str, Any]: + """ + Generate the variables for a single notification. + + Args: + notif: The outstanding notification for this room. + user_id: The user receiving the notification. + notif_event: The event related to the above notification. + room_state_ids: The event IDs of the current room state. + + Returns: + A dictionary to be added to the template context. + """ + results = await self.store.get_events_around( notif["room_id"], notif["event_id"], @@ -420,7 +459,7 @@ class Mailer: ) ret = { - "link": self.make_notif_link(notif), + "link": self._make_notif_link(notif), "ts": notif["received_ts"], "messages": [], } @@ -431,22 +470,51 @@ class Mailer: the_events.append(notif_event) for event in the_events: - messagevars = await self.get_message_vars(notif, event, room_state_ids) + messagevars = await self._get_message_vars(notif, event, room_state_ids) if messagevars is not None: ret["messages"].append(messagevars) return ret - async def get_message_vars( + async def _get_message_vars( self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str] ) -> Optional[Dict[str, Any]]: + """ + Generate the variables for a single event, if possible. + + Args: + notif: The outstanding notification for this room. + event: The event under consideration. + room_state_ids: The event IDs of the current room state. + + Returns: + A dictionary to be added to the template context, or None if the + event cannot be processed. + """ if event.type != EventTypes.Message and event.type != EventTypes.Encrypted: return None - sender_state_event_id = room_state_ids[("m.room.member", event.sender)] - sender_state_event = await self.store.get_event(sender_state_event_id) - sender_name = name_from_member_event(sender_state_event) - sender_avatar_url = sender_state_event.content.get("avatar_url") + # Get the sender's name and avatar from the room state. + type_state_key = ("m.room.member", event.sender) + sender_state_event_id = room_state_ids.get(type_state_key) + if sender_state_event_id: + sender_state_event = await self.store.get_event( + sender_state_event_id + ) # type: Optional[EventBase] + else: + # Attempt to check the historical state for the room. + historical_state = await self.state_store.get_state_for_event( + event.event_id, StateFilter.from_types((type_state_key,)) + ) + sender_state_event = historical_state.get(type_state_key) + + if sender_state_event: + sender_name = name_from_member_event(sender_state_event) + sender_avatar_url = sender_state_event.content.get("avatar_url") + else: + # No state could be found, fallback to the MXID. 
+ sender_name = event.sender + sender_avatar_url = None # 'hash' for deterministically picking default images: use # sender_hash % the number of default images to choose from @@ -471,18 +539,25 @@ class Mailer: ret["msgtype"] = msgtype if msgtype == "m.text": - self.add_text_message_vars(ret, event) + self._add_text_message_vars(ret, event) elif msgtype == "m.image": - self.add_image_message_vars(ret, event) + self._add_image_message_vars(ret, event) if "body" in event.content: ret["body_text_plain"] = event.content["body"] return ret - def add_text_message_vars( + def _add_text_message_vars( self, messagevars: Dict[str, Any], event: EventBase ) -> None: + """ + Potentially add a sanitised message body to the message variables. + + Args: + messagevars: The template context to be modified. + event: The event under consideration. + """ msgformat = event.content.get("format") messagevars["format"] = msgformat @@ -495,16 +570,20 @@ class Mailer: elif body: messagevars["body_text_html"] = safe_text(body) - def add_image_message_vars( + def _add_image_message_vars( self, messagevars: Dict[str, Any], event: EventBase ) -> None: """ Potentially add an image URL to the message variables. + + Args: + messagevars: The template context to be modified. + event: The event under consideration. """ if "url" in event.content: messagevars["image_url"] = event.content["url"] - async def make_summary_text_single_room( + async def _make_summary_text_single_room( self, room_id: str, notifs: List[Dict[str, Any]], @@ -517,7 +596,7 @@ class Mailer: Args: room_id: The ID of the room. - notifs: The notifications for this room. + notifs: The push actions for this room. room_state_ids: The state map for the room. notif_events: A map of event ID -> notification event. user_id: The user receiving the notification. @@ -600,11 +679,11 @@ class Mailer: "app": self.app_name, } - return await self.make_summary_text_from_member_events( + return await self._make_summary_text_from_member_events( room_id, notifs, room_state_ids, notif_events ) - async def make_summary_text( + async def _make_summary_text( self, notifs_by_room: Dict[str, List[Dict[str, Any]]], room_state_ids: Dict[str, StateMap[str]], @@ -615,7 +694,7 @@ class Mailer: Make a summary text for the email when multiple rooms have notifications. Args: - notifs_by_room: A map of room ID to the notifications for that room. + notifs_by_room: A map of room ID to the push actions for that room. room_state_ids: A map of room ID to the state map for that room. notif_events: A map of event ID -> notification event. reason: The reason this notification is being sent. @@ -632,11 +711,11 @@ class Mailer: } room_id = reason["room_id"] - return await self.make_summary_text_from_member_events( + return await self._make_summary_text_from_member_events( room_id, notifs_by_room[room_id], room_state_ids[room_id], notif_events ) - async def make_summary_text_from_member_events( + async def _make_summary_text_from_member_events( self, room_id: str, notifs: List[Dict[str, Any]], @@ -648,7 +727,7 @@ class Mailer: Args: room_id: The ID of the room. - notifs: The notifications for this room. + notifs: The push actions for this room. room_state_ids: The state map for the room. notif_events: A map of event ID -> notification event. 
@@ -657,14 +736,45 @@ class Mailer: """ # If the room doesn't have a name, say who the messages # are from explicitly to avoid, "messages in the Bob room" - sender_ids = {notif_events[n["event_id"]].sender for n in notifs} - member_events = await self.store.get_events( - [room_state_ids[("m.room.member", s)] for s in sender_ids] - ) + # Find the latest event ID for each sender, note that the notifications + # are already in descending received_ts. + sender_ids = {} + for n in notifs: + sender = notif_events[n["event_id"]].sender + if sender not in sender_ids: + sender_ids[sender] = n["event_id"] + + # Get the actual member events (in order to calculate a pretty name for + # the room). + member_event_ids = [] + member_events = {} + for sender_id, event_id in sender_ids.items(): + type_state_key = ("m.room.member", sender_id) + sender_state_event_id = room_state_ids.get(type_state_key) + if sender_state_event_id: + member_event_ids.append(sender_state_event_id) + else: + # Attempt to check the historical state for the room. + historical_state = await self.state_store.get_state_for_event( + event_id, StateFilter.from_types((type_state_key,)) + ) + sender_state_event = historical_state.get(type_state_key) + if sender_state_event: + member_events[event_id] = sender_state_event + member_events.update(await self.store.get_events(member_event_ids)) + + if not member_events: + # No member events were found! Maybe the room is empty? + # Fallback to the room ID (note that if there was a room name this + # would already have been used previously). + return self.email_subjects.messages_in_room % { + "room": room_id, + "app": self.app_name, + } # There was a single sender. - if len(sender_ids) == 1: + if len(member_events) == 1: return self.email_subjects.messages_from_person % { "person": descriptor_from_member_events(member_events.values()), "app": self.app_name, @@ -676,7 +786,16 @@ class Mailer: "app": self.app_name, } - def make_room_link(self, room_id: str) -> str: + def _make_room_link(self, room_id: str) -> str: + """ + Generate a link to open a room in the web client. + + Args: + room_id: The room ID to generate a link to. + + Returns: + A link to open a room in the web client. + """ if self.hs.config.email_riot_base_url: base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) elif self.app_name == "Vector": @@ -686,7 +805,16 @@ class Mailer: base_url = "https://matrix.to/#" return "%s/%s" % (base_url, room_id) - def make_notif_link(self, notif: Dict[str, str]) -> str: + def _make_notif_link(self, notif: Dict[str, str]) -> str: + """ + Generate a link to open an event in the web client. + + Args: + notif: The notification to generate a link for. + + Returns: + A link to open the notification in the web client. + """ if self.hs.config.email_riot_base_url: return "%s/#/room/%s/%s" % ( self.hs.config.email_riot_base_url, @@ -702,9 +830,20 @@ class Mailer: else: return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) - def make_unsubscribe_link( + def _make_unsubscribe_link( self, user_id: str, app_id: str, email_address: str ) -> str: + """ + Generate a link to unsubscribe from email notifications. + + Args: + user_id: The user receiving the notification. + app_id: The application receiving the notification. + email_address: The email address receiving the notification. + + Returns: + A link to unsubscribe from email notifications. 
+ """ params = { "access_token": self.macaroon_gen.generate_delete_pusher_token(user_id), "app_id": app_id, diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
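The recurring pattern in the mailer changes: look the sender's membership event up in the current room state, fall back to the state at the time of the event via a narrow `StateFilter`, and only then use the bare MXID or room ID. Condensed into one helper; the function name is invented, the calls mirror the hunks above:

from synapse.storage.state import StateFilter


async def get_sender_membership(store, state_store, room_state_ids, event):
    key = ("m.room.member", event.sender)
    event_id = room_state_ids.get(key)
    if event_id:
        # The sender still appears in the current state.
        return await store.get_event(event_id)
    # Otherwise fetch the room state as it was at this event, restricted
    # to the single membership event we care about.
    historical = await state_store.get_state_for_event(
        event.event_id, StateFilter.from_types((key,))
    )
    return historical.get(key)  # None if no state could be found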
index eed16dbfb5..343ab30bed 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py
@@ -62,7 +62,7 @@ class PusherPool: self.store = self.hs.get_datastore() self.clock = self.hs.get_clock() - self._account_validity = hs.config.account_validity + self._account_validity_enabled = hs.config.account_validity_enabled # We shard the handling of push notifications by user ID. self._pusher_shard_config = hs.config.push.pusher_shard_config @@ -78,8 +78,7 @@ class PusherPool: self.pushers = {} # type: Dict[str, Dict[str, Pusher]] def start(self) -> None: - """Starts the pushers off in a background process. - """ + """Starts the pushers off in a background process.""" if not self._should_start_pushers: logger.info("Not starting pushers because they are disabled in the config") return @@ -225,7 +224,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) @@ -255,7 +254,7 @@ class PusherPool: for u in users_affected: # Don't push if the user account has expired - if self._account_validity.enabled: + if self._account_validity_enabled: expired = await self.store.is_account_expired( u, self.clock.time_msec() ) @@ -297,8 +296,7 @@ class PusherPool: return pusher async def _start_pushers(self) -> None: - """Start all the pushers - """ + """Start all the pushers""" pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the @@ -335,7 +333,8 @@ class PusherPool: return None except Exception: logger.exception( - "Couldn't start pusher id %i: caught Exception", pusher_config.id, + "Couldn't start pusher id %i: caught Exception", + pusher_config.id, ) return None diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index bfd46a3730..0857efbe71 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py
@@ -86,8 +86,12 @@ REQUIREMENTS = [ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], - # we use execute_values with the fetch param, which arrived in psycopg 2.8. - "postgres": ["psycopg2>=2.8"], + "postgres": [ + # we use execute_values with the fetch param, which arrived in psycopg 2.8. + "psycopg2>=2.8 ; platform_python_implementation != 'PyPy'", + "psycopg2cffi>=2.8 ; platform_python_implementation == 'PyPy'", + "psycopg2cffi-compat==1.1 ; platform_python_implementation == 'PyPy'", + ], # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ @@ -112,6 +116,8 @@ CONDITIONAL_REQUIREMENTS = { "redis": ["txredisapi>=1.4.7", "hiredis"], } +CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"] + ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
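The psycopg2 requirement now uses PEP 508 environment markers, so installing the `postgres` extra resolves the CFFI build on PyPy and the C extension everywhere else. The markers can be inspected programmatically, for instance with the `packaging` library (a dependency of this snippet only, not of Synapse):

from packaging.requirements import Requirement

req = Requirement("psycopg2>=2.8 ; platform_python_implementation != 'PyPy'")
print(req.marker)             # platform_python_implementation != "PyPy"
print(req.marker.evaluate())  # True on CPython, False on PyPy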
index 288727a566..8a3f113e76 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py
@@ -273,7 +273,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths( - method, [pattern], self._check_auth_and_handle, self.__class__.__name__, + method, + [pattern], + self._check_auth_and_handle, + self.__class__.__name__, ) def _check_auth_and_handle(self, request, **kwargs): diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
index 52d32528ee..60899b6ad6 100644 --- a/synapse/replication/http/account_data.py +++ b/synapse/replication/http/account_data.py
@@ -175,7 +175,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): return {} async def _handle_request(self, request, user_id, room_id, tag): - max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) + max_stream_id = await self.handler.remove_tag_from_room( + user_id, + room_id, + tag, + ) return 200, {"max_stream_id": max_stream_id} diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 84e002f934..ce52613956 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py
@@ -98,6 +98,73 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): return 200, {"event_id": event_id, "stream_id": stream_id} +class ReplicationRemoteKnockRestServlet(ReplicationEndpoint): + """Perform a remote knock for the given user on the given room + + Request format: + + POST /_synapse/replication/remote_knock/:room_id/:user_id + + { + "requester": ..., + "remote_room_hosts": [...], + "content": { ... } + } + """ + + NAME = "remote_knock" + PATH_ARGS = ("room_id", "user_id") + + def __init__(self, hs): + super().__init__(hs) + + self.federation_handler = hs.get_federation_handler() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload( # type: ignore + requester: Requester, + room_id: str, + user_id: str, + remote_room_hosts: List[str], + content: JsonDict, + ): + """ + Args: + requester: The user making the request, according to the access token. + room_id: The ID of the room to knock on. + user_id: The ID of the knocking user. + remote_room_hosts: Servers to try and send the knock via. + content: The event content to use for the knock event. + """ + return { + "requester": requester.serialize(), + "remote_room_hosts": remote_room_hosts, + "content": content, + } + + async def _handle_request( # type: ignore + self, request: Request, room_id: str, user_id: str, + ): + content = parse_json_object_from_request(request) + + remote_room_hosts = content["remote_room_hosts"] + event_content = content["content"] + + requester = Requester.deserialize(self.store, content["requester"]) + + request.requester = requester + + logger.debug("remote_knock: %s on room: %s", user_id, room_id) + + event_id, stream_id = await self.federation_handler.do_knock( + remote_room_hosts, room_id, user_id, event_content + ) + + return 200, {"event_id": event_id, "stream_id": stream_id} + + class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): """Rejects an out-of-band invite we have received from a remote server @@ -160,7 +227,74 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # hopefully we're now on the master, so this won't recurse! event_id, stream_id = await self.member_handler.remote_reject_invite( - invite_event_id, txn_id, requester, event_content, + invite_event_id, + txn_id, + requester, + event_content, + ) + + return 200, {"event_id": event_id, "stream_id": stream_id} + + +class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint): + """Rescinds a local knock made on a remote room + + Request format: + + POST /_synapse/replication/remote_rescind_knock/:event_id + + { + "txn_id": ..., + "requester": ..., + "content": { ... } + } + """ + + NAME = "remote_rescind_knock" + PATH_ARGS = ("knock_event_id",) + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.member_handler = hs.get_room_member_handler() + + @staticmethod + async def _serialize_payload( # type: ignore + knock_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ): + """ + Args: + knock_event_id: The ID of the knock to be rescinded. + txn_id: An optional transaction ID supplied by the client. + requester: The user making the rescind request, according to the access token. + content: The content to include in the rescind event. 
+ """ + return { + "txn_id": txn_id, + "requester": requester.serialize(), + "content": content, + } + + async def _handle_request( # type: ignore + self, request: Request, knock_event_id: str, + ): + content = parse_json_object_from_request(request) + + txn_id = content["txn_id"] + event_content = content["content"] + + requester = Requester.deserialize(self.store, content["requester"]) + + request.requester = requester + + # hopefully we're now on the master, so this won't recurse! + event_id, stream_id = await self.member_handler.remote_rescind_knock( + knock_event_id, txn_id, requester, event_content, ) return 200, {"event_id": event_id, "stream_id": stream_id} diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
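Both new servlets follow the usual `ReplicationEndpoint` contract: a worker builds a client with `make_client` and calls it with the `PATH_ARGS` plus the keys produced by `_serialize_payload`. A sketch of the worker-side caller for the knock endpoint, modelled on how `synapse/handlers/room_member_worker.py` consumes these endpoints:

from synapse.replication.http.membership import (
    ReplicationRemoteKnockRestServlet,
)


class RoomMemberWorkerHandler:
    def __init__(self, hs):
        self._remote_knock_client = ReplicationRemoteKnockRestServlet.make_client(hs)

    async def remote_knock(
        self, requester, remote_room_hosts, room_id, user_id, content
    ):
        ret = await self._remote_knock_client(
            requester=requester,
            room_id=room_id,
            user_id=user_id,
            remote_room_hosts=remote_room_hosts,
            content=content,
        )
        return ret["event_id"], ret["stream_id"]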
index 7b12ec9060..d005f38767 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__) class ReplicationRegisterServlet(ReplicationEndpoint): - """Register a new user - """ + """Register a new user""" NAME = "register_user" PATH_ARGS = ("user_id",) @@ -97,8 +96,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): - """Run any post registration actions - """ + """Run any post registration actions""" NAME = "post_register" PATH_ARGS = ("user_id",) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index ac532ed588..0a9da79c32 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -196,8 +196,7 @@ class ErrorCommand(_SimpleCommand): class PingCommand(_SimpleCommand): - """Sent by either side as a keep alive. The data is arbitrary (often timestamp) - """ + """Sent by either side as a keep alive. The data is arbitrary (often timestamp)""" NAME = "PING" diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
index 34fa3ff5b3..d89a36f25a 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py
@@ -60,8 +60,7 @@ class ExternalCache: return self._redis_connection is not None async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: - """Add the key/value to the named cache, with the expiry time given. - """ + """Add the key/value to the named cache, with the expiry time given.""" if self._redis_connection is None: return @@ -76,13 +75,14 @@ class ExternalCache: return await make_deferred_yieldable( self._redis_connection.set( - self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + self._get_redis_key(cache_name, key), + encoded_value, + pexpire=expiry_ms, ) ) async def get(self, cache_name: str, key: str) -> Optional[Any]: - """Look up a key/value in the named cache. - """ + """Look up a key/value in the named cache.""" if self._redis_connection is None: return None diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
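`ExternalCache` degrades gracefully: with no Redis connection configured, `set` is a no-op and `get` returns `None`, so callers can use it unconditionally. An illustrative caller holding a homeserver reference; the cache and key names are invented, and the `get_external_cache` accessor is assumed from the surrounding codebase:

async def lookup_with_cache(hs):
    cache = hs.get_external_cache()
    await cache.set("example_cache", "some_key", {"answer": 42}, expiry_ms=60_000)
    return await cache.get("example_cache", "some_key")  # None on a miss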
index 8ea8dcd587..d1d00c3717 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -303,7 +303,9 @@ class ReplicationCommandHandler: hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory, + hs.config.redis.redis_host, + hs.config.redis.redis_port, + self._factory, ) else: client_name = hs.get_instance_name() @@ -313,13 +315,11 @@ class ReplicationCommandHandler: hs.get_reactor().connectTCP(host, port, self._factory) def get_streams(self) -> Dict[str, Stream]: - """Get a map from stream name to all streams. - """ + """Get a map from stream name to all streams.""" return self._streams def get_streams_to_replicate(self) -> List[Stream]: - """Get a list of streams that this instances replicates. - """ + """Get a list of streams that this instances replicates.""" return self._streams_to_replicate def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): @@ -340,7 +340,10 @@ class ReplicationCommandHandler: current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand( - stream.NAME, self._instance_name, current_token, current_token, + stream.NAME, + self._instance_name, + current_token, + current_token, ) ) @@ -592,8 +595,7 @@ class ReplicationCommandHandler: self.send_command(cmd, ignore_conn=conn) def new_connection(self, connection: AbstractConnection): - """Called when we have a new connection. - """ + """Called when we have a new connection.""" self._connections.append(connection) # If we are connected to replication as a client (rather than a server) @@ -620,8 +622,7 @@ class ReplicationCommandHandler: ) def lost_connection(self, connection: AbstractConnection): - """Called when a connection is closed/lost. - """ + """Called when a connection is closed/lost.""" # we no longer need _streams_by_connection for this connection. streams = self._streams_by_connection.pop(connection, None) if streams: @@ -678,15 +679,13 @@ class ReplicationCommandHandler: def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int ): - """Poke the master that a user has started/stopped syncing. - """ + """Poke the master that a user has started/stopped syncing.""" self.send_command( UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms) ) def send_remove_pusher(self, app_id: str, push_key: str, user_id: str): - """Poke the master to remove a pusher for a user - """ + """Poke the master to remove a pusher for a user""" cmd = RemovePusherCommand(app_id, push_key, user_id) self.send_command(cmd) @@ -699,8 +698,7 @@ class ReplicationCommandHandler: device_id: str, last_seen: int, ): - """Tell the master that the user made a request. - """ + """Tell the master that the user made a request.""" cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen) self.send_command(cmd) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 804da994ea..e0b4ad314d 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -222,8 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.send_error("ping timeout") def lineReceived(self, line: bytes): - """Called when we've received a line - """ + """Called when we've received a line""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_line(line) @@ -299,8 +298,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.on_connection_closed() def send_error(self, error_string, *args): - """Send an error to remote and close the connection. - """ + """Send an error to remote and close the connection.""" self.send_command(ErrorCommand(error_string % args)) self.close() @@ -341,8 +339,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_sent_command = self.clock.time_msec() def _queue_command(self, cmd): - """Queue the command until the connection is ready to write to again. - """ + """Queue the command until the connection is ready to write to again.""" logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) @@ -355,8 +352,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.close() def _send_pending_commands(self): - """Send any queued commandes - """ + """Send any queued commandes""" pending = self.pending_commands self.pending_commands = [] for cmd in pending: @@ -380,8 +376,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.state = ConnectionStates.PAUSED def resumeProducing(self): - """The remote has caught up after we started buffering! - """ + """The remote has caught up after we started buffering!""" logger.info("[%s] Resume producing", self.id()) self.state = ConnectionStates.ESTABLISHED self._send_pending_commands() @@ -440,8 +435,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): return "%s-%s" % (self.name, self.conn_id) def lineLengthExceeded(self, line): - """Called when we receive a line that is above the maximum line length - """ + """Called when we receive a line that is above the maximum line length""" self.send_error("Line length exceeded") @@ -495,21 +489,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_error("Wrong remote") def replicate(self): - """Send the subscription request to the server - """ + """Send the subscription request to the server""" logger.info("[%s] Subscribing to replication streams", self.id()) self.send_command(ReplicateCommand()) class AbstractConnection(abc.ABC): - """An interface for replication connections. - """ + """An interface for replication connections.""" @abc.abstractmethod def send_command(self, cmd: Command): - """Send the command down the connection - """ + """Send the command down the connection""" pass diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index fdd087683b..0e6155cf53 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -15,8 +15,9 @@ import logging from inspect import isawaitable -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast +import attr import txredisapi from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -42,6 +43,24 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +T = TypeVar("T") +V = TypeVar("V") + + +@attr.s +class ConstantProperty(Generic[T, V]): + """A descriptor that returns the given constant, ignoring attempts to set + it. + """ + + constant = attr.ib() # type: V + + def __get__(self, obj: Optional[T], objtype: Type[T] = None) -> V: + return self.constant + + def __set__(self, obj: Optional[T], value: V): + pass + class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): """Connection to redis subscribed to replication stream. @@ -104,8 +123,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): self.synapse_handler.send_positions_to_connection(self) def messageReceived(self, pattern: str, channel: str, message: str): - """Received a message from redis. - """ + """Received a message from redis.""" with PreserveLoggingContext(self._logging_context): self._parse_and_dispatch_message(message) @@ -118,7 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): cmd = parse_command_from_line(message) except Exception: logger.exception( - "Failed to parse replication line: %r", message, + "Failed to parse replication line: %r", + message, ) return @@ -195,6 +214,10 @@ class SynapseRedisFactory(txredisapi.RedisFactory): we detect dead connections. """ + # We want to *always* retry connecting, txredisapi will stop if there is a + # failure during certain operations, e.g. during AUTH. + continueTrying = cast(bool, ConstantProperty(True)) + def __init__( self, hs: "HomeServer", @@ -243,7 +266,6 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): """ maxDelay = 5 - continueTrying = True protocol = RedisSubscriber def __init__( diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
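`ConstantProperty` exists because txredisapi assigns `continueTrying = False` internally when certain operations fail (AUTH being the motivating case); a data descriptor whose `__set__` does nothing makes the flag effectively read-only. The behaviour in isolation, re-typed standalone so it runs without the surrounding module:

import attr


@attr.s
class ConstantProperty:
    constant = attr.ib()

    def __get__(self, obj, objtype=None):
        return self.constant

    def __set__(self, obj, value):
        pass  # assignment is silently discarded


class Factory:
    continueTrying = ConstantProperty(True)


f = Factory()
f.continueTrying = False  # what txredisapi would do on an AUTH failure
print(f.continueTrying)   # still True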
index 1d4ceac0f1..2018f9f29e 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -36,8 +36,7 @@ logger = logging.getLogger(__name__) class ReplicationStreamProtocolFactory(Factory): - """Factory for new replication connections. - """ + """Factory for new replication connections.""" def __init__(self, hs): self.command_handler = hs.get_tcp_replication() @@ -181,7 +180,8 @@ class ReplicationStreamer: raise logger.debug( - "Sending %d updates", len(updates), + "Sending %d updates", + len(updates), ) if updates: diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 61b282ab2d..38809b5b7c 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -183,7 +183,10 @@ class Stream: return [], upto_token, False updates, upto_token, limited = await self.update_function( - instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, + instance_name, + from_token, + upto_token, + _STREAM_UPDATE_TARGET_ROW_COUNT, ) return updates, upto_token, limited @@ -339,8 +342,7 @@ class ReceiptsStream(Stream): class PushRulesStream(Stream): - """A user has changed their push rules - """ + """A user has changed their push rules""" PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str @@ -362,8 +364,7 @@ class PushRulesStream(Stream): class PushersStream(Stream): - """A user has added/changed/removed a pusher - """ + """A user has added/changed/removed a pusher""" PushersStreamRow = namedtuple( "PushersStreamRow", @@ -416,8 +417,7 @@ class CachesStream(Stream): class PublicRoomsStream(Stream): - """The public rooms list changed - """ + """The public rooms list changed""" PublicRoomsStreamRow = namedtuple( "PublicRoomsStreamRow", @@ -463,8 +463,7 @@ class DeviceListsStream(Stream): class ToDeviceStream(Stream): - """New to_device messages for a client - """ + """New to_device messages for a client""" ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str @@ -481,8 +480,7 @@ class ToDeviceStream(Stream): class TagAccountDataStream(Stream): - """Someone added/removed a tag for a room - """ + """Someone added/removed a tag for a room""" TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict @@ -501,8 +499,7 @@ class TagAccountDataStream(Stream): class AccountDataStream(Stream): - """Global or per room account data was changed - """ + """Global or per room account data was changed""" AccountDataStreamRow = namedtuple( "AccountDataStream", @@ -589,8 +586,7 @@ class GroupServerStream(Stream): class UserSignatureStream(Stream): - """A user has signed their own device with their user-signing key - """ + """A user has signed their own device with their user-signing key""" UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 86a62b71eb..fa5e37ba7b 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -113,8 +113,7 @@ TypeToRow = {Row.TypeId: Row for Row in _EventRows} class EventsStream(Stream): - """We received a new event, or an event went from being an outlier to not - """ + """We received a new event, or an event went from being an outlier to not""" NAME = "events" diff --git a/synapse/res/templates/account_previously_renewed.html b/synapse/res/templates/account_previously_renewed.html new file mode 100644
index 0000000000..b751359bdf --- /dev/null +++ b/synapse/res/templates/account_previously_renewed.html
@@ -0,0 +1 @@ +<html><body>Your account is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body></html> diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
index 894da030af..e8c0f52f05 100644 --- a/synapse/res/templates/account_renewed.html +++ b/synapse/res/templates/account_renewed.html
@@ -1 +1 @@ -<html><body>Your account has been successfully renewed.</body><html> +<html><body>Your account has been successfully renewed and is valid until {{ expiration_ts|format_ts("%d-%m-%Y") }}.</body></html> diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css
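Both templates lean on a `format_ts` Jinja filter to render the millisecond `expiration_ts` as a date; the filter itself is not part of this diff. A plausible equivalent, assuming Synapse's convention of millisecond epoch timestamps (the wiring and example value below are illustrative):

```python
import time

from jinja2 import Environment, FileSystemLoader


def format_ts(value: int, format: str = "%Y-%m-%d") -> str:
    # Synapse timestamps are milliseconds since the epoch.
    return time.strftime(format, time.localtime(value / 1000))


# Illustrative wiring; Synapse builds its own Jinja environment elsewhere.
env = Environment(loader=FileSystemLoader("synapse/res/templates"))
env.filters["format_ts"] = format_ts

html = env.get_template("account_renewed.html").render(
    expiration_ts=1735689600000  # example value, in milliseconds
)
```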
index 46b309ea4e..338214f5d0 100644 --- a/synapse/res/templates/sso.css +++ b/synapse/res/templates/sso.css
@@ -1,16 +1,26 @@ -body { +body, input, select, textarea { font-family: "Inter", "Helvetica", "Arial", sans-serif; font-size: 14px; color: #17191C; } -header { +header, footer { max-width: 480px; width: 100%; margin: 24px auto; text-align: center; } +@media screen and (min-width: 800px) { + header { + margin-top: 90px; + } +} + +header { + min-height: 60px; +} + header p { color: #737D8C; line-height: 24px; @@ -20,6 +30,10 @@ h1 { font-size: 24px; } +a { + color: #418DED; +} + .error_page h1 { color: #FE2928; } @@ -47,6 +61,9 @@ main { .primary-button { border: none; + -webkit-appearance: none; + -moz-appearance: none; + appearance: none; text-decoration: none; padding: 12px; color: white; @@ -63,8 +80,17 @@ main { .profile { display: flex; + flex-direction: column; + align-items: center; justify-content: center; - margin: 24px 0; + margin: 24px; + padding: 13px; + border: 1px solid #E9ECF1; + border-radius: 4px; +} + +.profile.with-avatar { + margin-top: 42px; /* (36px / 2) + 24px*/ } .profile .avatar { @@ -72,17 +98,32 @@ main { height: 36px; border-radius: 100%; display: block; - margin-right: 8px; + margin-top: -32px; + margin-bottom: 8px; } .profile .display-name { font-weight: bold; margin-bottom: 4px; + font-size: 15px; + line-height: 18px; } .profile .user-id { color: #737D8C; + font-size: 12px; + line-height: 12px; } -.profile .display-name, .profile .user-id { - line-height: 18px; +footer { + margin-top: 80px; } + +footer svg { + display: block; + width: 46px; + margin: 0px auto 12px auto; +} + +footer p { + color: #737D8C; +} \ No newline at end of file diff --git a/synapse/res/templates/sso_account_deactivated.html b/synapse/res/templates/sso_account_deactivated.html
index 50a0979c2f..c3e4deed93 100644 --- a/synapse/res/templates/sso_account_deactivated.html +++ b/synapse/res/templates/sso_account_deactivated.html
@@ -20,5 +20,6 @@ administrator. </p> </header> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index 36850a2d6a..f4fdc40b22 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html
@@ -1,12 +1,29 @@ <!DOCTYPE html> <html lang="en"> <head> - <title>Synapse Login</title> + <title>Create your account</title> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, user-scalable=no"> + <script type="text/javascript"> + let wasKeyboard = false; + document.addEventListener("mousedown", function() { wasKeyboard = false; }); + document.addEventListener("keydown", function() { wasKeyboard = true; }); + document.addEventListener("focusin", function() { + if (wasKeyboard) { + document.body.classList.add("keyboard-focus"); + } else { + document.body.classList.remove("keyboard-focus"); + } + }); + </script> <style type="text/css"> {% include "sso.css" without context %} + body.keyboard-focus :focus, body.keyboard-focus .username_input:focus-within { + outline: 3px solid #17191C; + outline-offset: 4px; + } + .username_input { display: flex; border: 2px solid #418DED; @@ -33,11 +50,12 @@ .username_input label { position: absolute; - top: -8px; + top: -5px; left: 14px; - font-size: 80%; + font-size: 10px; + line-height: 10px; background: white; - padding: 2px; + padding: 0 2px; } .username_input input { @@ -47,6 +65,13 @@ border: none; } + /* only clear the outline if we know it will be shown on the parent div using :focus-within */ + @supports selector(:focus-within) { + .username_input input { + outline: none !important; + } + } + .username_input div { color: #8D99A5; } @@ -65,6 +90,7 @@ .idp-pick-details .idp-detail { border-top: 1px solid #E9ECF1; padding: 12px; + display: block; } .idp-pick-details .check-row { display: flex; @@ -117,43 +143,44 @@ </div> <output for="username_input" id="field-username-output"></output> <input type="submit" value="Continue" class="primary-button"> - {% if user_attributes %} + {% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %} <section class="idp-pick-details"> <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2> {% if user_attributes.avatar_url %} - <div class="idp-detail idp-avatar"> + <label class="idp-detail idp-avatar" for="idp-avatar"> <div class="check-row"> - <label for="idp-avatar" class="name">Avatar</label> - <label for="idp-avatar" class="use">Use</label> + <span class="name">Avatar</span> + <span class="use">Use</span> <input type="checkbox" name="use_avatar" id="idp-avatar" value="true" checked> </div> <img src="{{ user_attributes.avatar_url }}" class="avatar" /> - </div> + </label> {% endif %} {% if user_attributes.display_name %} - <div class="idp-detail"> + <label class="idp-detail" for="idp-displayname"> <div class="check-row"> - <label for="idp-displayname" class="name">Display name</label> - <label for="idp-displayname" class="use">Use</label> + <span class="name">Display name</span> + <span class="use">Use</span> <input type="checkbox" name="use_display_name" id="idp-displayname" value="true" checked> </div> <p class="idp-value">{{ user_attributes.display_name }}</p> - </div> + </label> {% endif %} {% for email in user_attributes.emails %} - <div class="idp-detail"> + <label class="idp-detail" for="idp-email{{ loop.index }}"> <div class="check-row"> - <label for="idp-email{{ loop.index }}" class="name">E-mail</label> - <label for="idp-email{{ loop.index }}" class="use">Use</label> + <span class="name">E-mail</span> + <span class="use">Use</span> <input type="checkbox" name="use_email" id="idp-email{{ loop.index }}" value="{{ email }}" checked> </div> <p class="idp-value">{{ email }}</p> - </div> + </label> {% endfor %} 
</section> {% endif %} </form> </main> + {% include "sso_footer.html" without context %} <script type="text/javascript"> {% include "sso_auth_account_details.js" without context %} </script> diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
index c9bd4bef20..c061698a21 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -12,7 +12,7 @@ <header> <h1>That doesn't look right</h1> <p> - <strong>We were unable to validate your {{ server_name }} account</strong> + <strong>We were unable to validate your {{ server_name }} account</strong> via single&nbsp;sign&#8209;on&nbsp;(SSO), because the SSO Identity Provider returned different details than when you logged in. </p> @@ -21,5 +21,6 @@ the Identity Provider as when you log into your account. </p> </header> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html
index 2099c2f1f8..f9d0456f0a 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html
@@ -2,7 +2,7 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>Authentication</title> + <title>Confirm it's you</title> <meta name="viewport" content="width=device-width, user-scalable=no"> <style type="text/css"> {% include "sso.css" without context %} @@ -24,5 +24,6 @@ Continue with {{ idp.idp_name }} </a> </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html
index 3b975d7219..1ed3967e87 100644 --- a/synapse/res/templates/sso_auth_success.html +++ b/synapse/res/templates/sso_auth_success.html
@@ -23,5 +23,6 @@ application. </p> </header> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html
index b223ca0f56..472309c350 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html
@@ -38,6 +38,7 @@ <p>{{ error }}</p> </div> </header> + {% include "sso_footer.html" without context %} <script type="text/javascript"> // Error handling to support Auth0 errors that we might get through a GET request diff --git a/synapse/res/templates/sso_footer.html b/synapse/res/templates/sso_footer.html new file mode 100644
index 0000000000..588a3d508d --- /dev/null +++ b/synapse/res/templates/sso_footer.html
@@ -0,0 +1,19 @@ +<footer> + <svg role="img" aria-label="[Matrix logo]" viewBox="0 0 200 85" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"> + <g id="parent" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd"> + <g id="child" transform="translate(-122.000000, -6.000000)" fill="#000000" fill-rule="nonzero"> + <g id="matrix-logo" transform="translate(122.000000, 6.000000)"> + <polygon id="left-bracket" points="2.24708861 1.93811009 2.24708861 82.7268844 8.10278481 82.7268844 8.10278481 84.6652459 0 84.6652459 0 0 8.10278481 0 8.10278481 1.93811009"></polygon> + <path d="M24.8073418,27.5493174 L24.8073418,31.6376991 L24.924557,31.6376991 C26.0227848,30.0814294 27.3455696,28.8730642 28.8951899,28.0163743 C30.4437975,27.1611927 32.2189873,26.7318422 34.218481,26.7318422 C36.1394937,26.7318422 37.8946835,27.102622 39.4825316,27.8416679 C41.0708861,28.5819706 42.276962,29.8856073 43.1005063,31.7548404 C44.0017722,30.431345 45.2270886,29.2629486 46.7767089,28.2506569 C48.3253165,27.2388679 50.158481,26.7318422 52.2764557,26.7318422 C53.8843038,26.7318422 55.3736709,26.9269101 56.7473418,27.3162917 C58.1189873,27.7056734 59.295443,28.3285835 60.2759494,29.185022 C61.255443,30.0422147 62.02,31.1615927 62.5701266,32.5426532 C63.1187342,33.9262275 63.3936709,35.5898349 63.3936709,37.5372459 L63.3936709,57.7443688 L55.0410127,57.7441174 L55.0410127,40.6319376 C55.0410127,39.6201486 55.0020253,38.6661761 54.9232911,37.7700202 C54.8440506,36.8751211 54.6293671,36.0968606 54.2764557,35.4339817 C53.9232911,34.772611 53.403038,34.2464807 52.7177215,33.8568477 C52.0313924,33.4689743 51.0997468,33.2731523 49.9235443,33.2731523 C48.7473418,33.2731523 47.7962025,33.4983853 47.0706329,33.944578 C46.344557,34.393033 45.7764557,34.9774826 45.3650633,35.6969211 C44.9534177,36.4181193 44.6787342,37.2353431 44.5417722,38.150855 C44.4037975,39.0653615 44.3356962,39.9904257 44.3356962,40.9247908 L44.3356962,57.7443688 L35.9835443,57.7443688 L35.9835443,40.8079009 C35.9835443,39.9124991 35.963038,39.0263982 35.9253165,38.150855 C35.8853165,37.2743064 35.7192405,36.4666349 35.424557,35.7263321 C35.1303797,34.9872862 34.64,34.393033 33.9539241,33.944578 C33.2675949,33.4983853 32.2579747,33.2731523 30.9248101,33.2731523 C30.5321519,33.2731523 30.0126582,33.3608826 29.3663291,33.5365945 C28.7192405,33.7118037 28.0913924,34.0433688 27.4840506,34.5292789 C26.875443,35.0164459 26.3564557,35.7172826 25.9250633,36.6315376 C25.4934177,37.5470495 25.2779747,38.7436 25.2779747,40.2229486 L25.2779747,57.7441174 L16.9260759,57.7443688 L16.9260759,27.5493174 L24.8073418,27.5493174 Z" id="m"></path> + <path d="M68.7455696,31.9886202 C69.6075949,30.7033339 70.7060759,29.672189 72.0397468,28.8926716 C73.3724051,28.1141596 74.8716456,27.5596239 76.5387342,27.2283101 C78.2050633,26.8977505 79.8817722,26.7315908 81.5678481,26.7315908 C83.0974684,26.7315908 84.6458228,26.8391798 86.2144304,27.0525982 C87.7827848,27.2675248 89.2144304,27.6865688 90.5086076,28.3087248 C91.8025316,28.9313835 92.8610127,29.7983798 93.6848101,30.9074514 C94.5083544,32.0170257 94.92,33.4870734 94.92,35.3173431 L94.92,51.026844 C94.92,52.3913138 94.998481,53.6941963 95.1556962,54.9400165 C95.3113924,56.1865908 95.5863291,57.120956 95.9787342,57.7436147 L87.5091139,57.7436147 C87.3518987,57.276055 87.2240506,56.7996972 87.1265823,56.3125303 C87.0278481,55.8266202 86.9592405,55.3301523 86.9207595,54.8236294 C85.5873418,56.1865908 84.0182278,57.1405633 82.2156962,57.6857982 C80.4113924,58.2295248 
78.5683544,58.503022 76.6860759,58.503022 C75.2346835,58.503022 73.8817722,58.3275615 72.6270886,57.9776459 C71.3718987,57.6269761 70.2744304,57.082244 69.3334177,56.3411872 C68.3921519,55.602644 67.656962,54.6680275 67.1275949,53.5390972 C66.5982278,52.410167 66.3331646,51.065556 66.3331646,49.5087835 C66.3331646,47.7961578 66.6367089,46.384178 67.2455696,45.2756092 C67.8529114,44.1652807 68.6367089,43.2799339 69.5987342,42.6173064 C70.5589873,41.9556844 71.6567089,41.4592165 72.8924051,41.1284055 C74.1273418,40.7978459 75.3721519,40.5356606 76.6270886,40.3398385 C77.8820253,40.1457761 79.116962,39.9896716 80.3329114,39.873033 C81.5483544,39.7558917 82.6270886,39.5804312 83.5681013,39.3469028 C84.5093671,39.1133743 85.2536709,38.7732624 85.8032911,38.3250587 C86.3513924,37.8773578 86.6063291,37.2252881 86.5678481,36.3680954 C86.5678481,35.4731963 86.4210127,34.7620532 86.1268354,34.2366771 C85.8329114,33.7113009 85.4405063,33.3018092 84.9506329,33.0099615 C84.4602532,32.7181138 83.8916456,32.5232972 83.2450633,32.4255119 C82.5977215,32.3294862 81.9010127,32.2797138 81.156962,32.2797138 C79.5098734,32.2797138 78.2159494,32.6303835 77.2746835,33.3312202 C76.3339241,34.0320569 75.7837975,35.2007046 75.6275949,36.8354037 L67.275443,36.8354037 C67.3924051,34.8892495 67.8817722,33.2726495 68.7455696,31.9886202 Z M85.2440506,43.6984752 C84.7149367,43.873433 84.1460759,44.0189798 83.5387342,44.1361211 C82.9306329,44.253011 82.2936709,44.350545 81.6270886,44.4279688 C80.96,44.5066495 80.2934177,44.6034294 79.6273418,44.7203193 C78.9994937,44.8362037 78.3820253,44.9933138 77.7749367,45.1871248 C77.1663291,45.3829468 76.636962,45.6451321 76.1865823,45.9759431 C75.7349367,46.3070055 75.3724051,46.7263009 75.0979747,47.2313156 C74.8232911,47.7375872 74.6863291,48.380356 74.6863291,49.1588679 C74.6863291,49.8979138 74.8232911,50.5218294 75.0979747,51.026844 C75.3724051,51.5338697 75.7455696,51.9328037 76.2159494,52.2246514 C76.6863291,52.5164991 77.2349367,52.7213706 77.8632911,52.8375064 C78.4898734,52.9546477 79.136962,53.012967 79.8037975,53.012967 C81.4506329,53.012967 82.724557,52.740978 83.6273418,52.1952404 C84.5288608,51.6507596 85.1949367,50.9981872 85.6270886,50.2382771 C86.0579747,49.4793725 86.323038,48.7119211 86.4212658,47.9321523 C86.518481,47.1536404 86.5681013,46.5304789 86.5681013,46.063422 L86.5681013,42.9677248 C86.2146835,43.2799339 85.7736709,43.5230147 85.2440506,43.6984752 Z" id="a"></path> + <path d="M116.917975,27.5493174 L116.917975,33.0976917 L110.801266,33.0976917 L110.801266,48.0492936 C110.801266,49.4502128 111.036203,50.3850807 111.507089,50.8518862 C111.976962,51.3191945 112.918734,51.5527229 114.33038,51.5527229 C114.801013,51.5527229 115.251392,51.5336183 115.683038,51.4944037 C116.114177,51.4561945 116.526076,51.3968697 116.917975,51.3194459 L116.917975,57.7438661 C116.212152,57.860756 115.427595,57.9381798 114.565316,57.9778972 C113.702785,58.0153523 112.859747,58.0357138 112.036203,58.0357138 C110.742278,58.0357138 109.516456,57.9477321 108.36,57.7722716 C107.202785,57.5975651 106.183544,57.2577046 105.301519,56.7509303 C104.418987,56.2454128 103.722785,55.5242147 103.213418,54.5898495 C102.703038,53.6562385 102.448608,52.4292716 102.448608,50.9099541 L102.448608,33.0976917 L97.3903797,33.0976917 L97.3903797,27.5493174 L102.448608,27.5493174 L102.448608,18.4967596 L110.801013,18.4967596 L110.801013,27.5493174 L116.917975,27.5493174 Z" id="t"></path> + <path d="M128.857975,27.5493174 L128.857975,33.1565138 L128.975696,33.1565138 C129.367089,32.2213945 
129.896203,31.3559064 130.563544,30.557033 C131.23038,29.7596679 131.99443,29.0776844 132.857215,28.5130936 C133.719241,27.9495083 134.641266,27.5113596 135.622532,27.1988991 C136.601772,26.8879468 137.622025,26.7315908 138.681013,26.7315908 C139.229873,26.7315908 139.836962,26.8296275 140.504304,27.0239413 L140.504304,34.7336477 C140.111646,34.6552183 139.641013,34.586844 139.092658,34.5290275 C138.543291,34.4704569 138.014177,34.4410459 137.504304,34.4410459 C135.974937,34.4410459 134.681013,34.6949358 133.622785,35.2004532 C132.564051,35.7067248 131.711392,36.397255 131.064051,37.2735523 C130.417215,38.1501009 129.955443,39.1714422 129.681266,40.3398385 C129.407089,41.5074807 129.269873,42.7736624 129.269873,44.1361211 L129.269873,57.7438661 L120.917722,57.7438661 L120.917722,27.5493174 L128.857975,27.5493174 Z" id="r"></path> + <path d="M144.033165,22.8767376 L144.033165,16.0435798 L152.386076,16.0435798 L152.386076,22.8767376 L144.033165,22.8767376 Z M152.386076,27.5493174 L152.386076,57.7438661 L144.033165,57.7438661 L144.033165,27.5493174 L152.386076,27.5493174 Z" id="i"></path> + <polygon id="x" points="156.738228 27.5493174 166.266582 27.5493174 171.619494 35.4337303 176.913418 27.5493174 186.147848 27.5493174 176.148861 41.6831927 187.383544 57.7441174 177.85443 57.7441174 171.501772 48.2245028 165.148861 57.7441174 155.797468 57.7441174 166.737468 41.8589046"></polygon> + <polygon id="right-bracket" points="197.580759 82.7268844 197.580759 1.93811009 191.725063 1.93811009 191.725063 0 199.828354 0 199.828354 84.6652459 191.725063 84.6652459 191.725063 82.7268844"></polygon> + </g> + </g> + </g> + </svg> + <p>An open network for secure, decentralized communication.<br>© 2021 The Matrix.org Foundation C.I.C.</p> +</footer> \ No newline at end of file diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
index 62a640dad2..53b82db84e 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -2,30 +2,60 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <link rel="stylesheet" href="/_matrix/static/client/login/style.css"> - <title>{{ server_name }} Login</title> + <title>Choose identity provider</title> + <style type="text/css"> + {% include "sso.css" without context %} + + .providers { + list-style: none; + padding: 0; + } + + .providers li { + margin: 12px; + } + + .providers a { + display: block; + border-radius: 4px; + border: 1px solid #17191C; + padding: 8px; + text-align: center; + text-decoration: none; + color: #17191C; + display: flex; + align-items: center; + font-weight: bold; + } + + .providers a img { + width: 24px; + height: 24px; + } + .providers a span { + flex: 1; + } + </style> </head> <body> - <div id="container"> - <h1 id="title">{{ server_name }} Login</h1> - <div class="login_flow"> - <p>Choose one of the following identity providers:</p> - <form> - <input type="hidden" name="redirectUrl" value="{{ redirect_url }}"> - <ul class="radiobuttons"> -{% for p in providers %} - <li> - <input type="radio" name="idp" id="prov{{ loop.index }}" value="{{ p.idp_id }}"> - <label for="prov{{ loop.index }}">{{ p.idp_name }}</label> -{% if p.idp_icon %} + <header> + <h1>Log in to {{ server_name }} </h1> + <p>Choose an identity provider to log in</p> + </header> + <main> + <ul class="providers"> + {% for p in providers %} + <li> + <a href="pick_idp?idp={{ p.idp_id }}&redirectUrl={{ redirect_url | urlencode }}"> + {% if p.idp_icon %} <img src="{{ p.idp_icon | mxc_to_http(32, 32) }}"/> -{% endif %} - </li> -{% endfor %} - </ul> - <input type="submit" class="button button--full-width" id="button-submit" value="Submit"> - </form> - </div> - </div> + {% endif %} + <span>{{ p.idp_name }}</span> + </a> + </li> + {% endfor %} + </ul> + </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html
index 8c33787c54..68c8b9f33a 100644 --- a/synapse/res/templates/sso_new_user_consent.html +++ b/synapse/res/templates/sso_new_user_consent.html
@@ -2,7 +2,7 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>SSO redirect confirmation</title> + <title>Agree to terms and conditions</title> <meta name="viewport" content="width=device-width, user-scalable=no"> <style type="text/css"> {% include "sso.css" without context %} @@ -18,22 +18,15 @@ <p>Agree to the terms to create your account.</p> </header> <main> - <!-- {% if user_profile.avatar_url and user_profile.display_name %} --> - <div class="profile"> - <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> - <div class="profile-details"> - <div class="display-name">{{ user_profile.display_name }}</div> - <div class="user-id">{{ user_id }}</div> - </div> - </div> - <!-- {% endif %} --> + {% include "sso_partial_profile.html" %} <form method="post" action="{{my_url}}" id="consent_form"> <p> <input id="accepted_version" type="checkbox" name="accepted_version" value="{{ consent_version }}" required> - <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank">terms and conditions</a>.</label> + <label for="accepted_version">I have read and agree to the <a href="{{ terms_url }}" target="_blank" rel="noopener">terms and conditions</a>.</label> </p> <input type="submit" class="primary-button" value="Continue"/> </form> </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/res/templates/sso_partial_profile.html b/synapse/res/templates/sso_partial_profile.html new file mode 100644
index 0000000000..c9c76c455e --- /dev/null +++ b/synapse/res/templates/sso_partial_profile.html
@@ -0,0 +1,19 @@ +{# html fragment to be included in SSO pages, to show the user's profile #} + +<div class="profile{% if user_profile.avatar_url %} with-avatar{% endif %}"> + {% if user_profile.avatar_url %} + <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> + {% endif %} + {# users that signed up with SSO will have a display_name of some sort; + however that is not the case for users who signed up via other + methods, so we need to handle that. + #} + {% if user_profile.display_name %} + <div class="display-name">{{ user_profile.display_name }}</div> + {% else %} + {# split the userid on ':', take the part before the first ':', + and then remove the leading '@'. #} + <div class="display-name">{{ user_id.split(":")[0][1:] }}</div> + {% endif %} + <div class="user-id">{{ user_id }}</div> +</div> diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
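The fallback branch in the fragment above derives a display name from the full Matrix ID by slicing. The same logic as plain Python, for reference:

```python
def fallback_display_name(user_id: str) -> str:
    """Mirror of the template expression user_id.split(":")[0][1:]: take the
    part before the first colon, then drop the leading '@'."""
    return user_id.split(":")[0][1:]


assert fallback_display_name("@alice:example.com") == "alice"
```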
index d1328a6969..1b01471ac8 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -2,35 +2,39 @@ <html lang="en"> <head> <meta charset="UTF-8"> - <title>SSO redirect confirmation</title> + <title>Continue to your account</title> <meta name="viewport" content="width=device-width, user-scalable=no"> <style type="text/css"> {% include "sso.css" without context %} + + .confirm-trust { + margin: 34px 0; + color: #8D99A5; + } + .confirm-trust strong { + color: #17191C; + } + + .confirm-trust::before { + content: ""; + background-image: url('data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTgiIGhlaWdodD0iMTgiIHZpZXdCb3g9IjAgMCAxOCAxOCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZmlsbC1ydWxlPSJldmVub2RkIiBjbGlwLXJ1bGU9ImV2ZW5vZGQiIGQ9Ik0xNi41IDlDMTYuNSAxMy4xNDIxIDEzLjE0MjEgMTYuNSA5IDE2LjVDNC44NTc4NiAxNi41IDEuNSAxMy4xNDIxIDEuNSA5QzEuNSA0Ljg1Nzg2IDQuODU3ODYgMS41IDkgMS41QzEzLjE0MjEgMS41IDE2LjUgNC44NTc4NiAxNi41IDlaTTcuMjUgOUM3LjI1IDkuNDY1OTYgNy41Njg2OSA5Ljg1NzQ4IDggOS45Njg1VjEyLjM3NUM4IDEyLjkyNzMgOC40NDc3MiAxMy4zNzUgOSAxMy4zNzVIMTAuMTI1QzEwLjY3NzMgMTMuMzc1IDExLjEyNSAxMi45MjczIDExLjEyNSAxMi4zNzVDMTEuMTI1IDExLjgyMjcgMTAuNjc3MyAxMS4zNzUgMTAuMTI1IDExLjM3NUgxMFY5QzEwIDguOTY1NDggOS45OTgyNSA4LjkzMTM3IDkuOTk0ODQgOC44OTc3NkM5Ljk0MzYzIDguMzkzNSA5LjUxNzc3IDggOSA4SDguMjVDNy42OTc3MiA4IDcuMjUgOC40NDc3MiA3LjI1IDlaTTkgNy41QzkuNjIxMzIgNy41IDEwLjEyNSA2Ljk5NjMyIDEwLjEyNSA2LjM3NUMxMC4xMjUgNS43NTM2OCA5LjYyMTMyIDUuMjUgOSA1LjI1QzguMzc4NjggNS4yNSA3Ljg3NSA1Ljc1MzY4IDcuODc1IDYuMzc1QzcuODc1IDYuOTk2MzIgOC4zNzg2OCA3LjUgOSA3LjVaIiBmaWxsPSIjQzFDNkNEIi8+Cjwvc3ZnPgoK'); + background-repeat: no-repeat; + width: 24px; + height: 24px; + display: block; + float: left; + } </style> </head> <body> <header> - {% if new_user %} - <h1>Your account is now ready</h1> - <p>You've made your account on {{ server_name }}.</p> - {% else %} - <h1>Log in</h1> - {% endif %} - <p>Continue to confirm you trust <strong>{{ display_url }}</strong>.</p> + <h1>Continue to your account</h1> </header> <main> - {% if user_profile.avatar_url %} - <div class="profile"> - <img src="{{ user_profile.avatar_url | mxc_to_http(64, 64) }}" class="avatar" /> - <div class="profile-details"> - {% if user_profile.display_name %} - <div class="display-name">{{ user_profile.display_name }}</div> - {% endif %} - <div class="user-id">{{ user_id }}</div> - </div> - </div> - {% endif %} + {% include "sso_partial_profile.html" %} + <p class="confirm-trust">Continuing will grant <strong>{{ display_url }}</strong> access to your account.</p> <a href="{{ redirect_url }}" class="primary-button">Continue</a> </main> + {% include "sso_footer.html" without context %} </body> </html> diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 40f5c32db2..ee3a9af569 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -39,6 +39,7 @@ from synapse.rest.client.v2_alpha import ( filter, groups, keys, + knock, notifications, openid, password_policy, @@ -119,8 +120,9 @@ class ClientRestResource(JsonResource): room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) password_policy.register_servlets(hs, client_resource) + knock.register_servlets(hs, client_resource) # moving to /_synapse/admin admin.register_servlets_for_client_rest_resource(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index f5c5d164f9..8457db1e22 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -42,6 +42,7 @@ from synapse.rest.admin.rooms import ( JoinRoomAliasServlet, ListRoomRestServlet, MakeRoomAdminRestServlet, + RoomEventContextServlet, RoomMembersRestServlet, RoomRestServlet, RoomStateRestServlet, @@ -238,6 +239,7 @@ def register_servlets(hs, http_server): MakeRoomAdminRestServlet(hs).register(http_server) ShadowBanRestServlet(hs).register(http_server) ForwardExtremitiesRestServlet(hs).register(http_server) + RoomEventContextServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index d0c86b204a..ebc587aa06 100644 --- a/synapse/rest/admin/groups.py +++ b/synapse/rest/admin/groups.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__) class DeleteGroupAdminRestServlet(RestServlet): - """Allows deleting of local groups - """ + """Allows deleting of local groups""" PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)") diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 8720b1401f..b996862c05 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py
@@ -119,8 +119,7 @@ class QuarantineMediaByID(RestServlet): class ProtectMediaByID(RestServlet): - """Protect local media from being quarantined. - """ + """Protect local media from being quarantined.""" PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)") @@ -141,8 +140,7 @@ class ProtectMediaByID(RestServlet): class ListMediaInRoom(RestServlet): - """Lists all of the media in a given room. - """ + """Lists all of the media in a given room.""" PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media") @@ -180,8 +178,7 @@ class PurgeMediaCacheRestServlet(RestServlet): class DeleteMediaByID(RestServlet): - """Delete local media by a given ID. Removes it from this server. - """ + """Delete local media by a given ID. Removes it from this server.""" PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)") diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 3e57e6a4d0..e64582cffd 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -14,15 +14,18 @@ # limitations under the License. import logging from http import HTTPStatus -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple +from urllib import parse as urlparse from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.api.filtering import Filter from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_integer, parse_json_object_from_request, + parse_list_from_args, parse_string, ) from synapse.http.site import SynapseRequest @@ -33,6 +36,7 @@ from synapse.rest.admin._base import ( ) from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -362,10 +366,8 @@ class JoinRoomAliasServlet(RestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier try: - remote_room_hosts = [ - x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] - except Exception: + remote_room_hosts = parse_list_from_args(request.args, "server_name") + except KeyError: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler @@ -479,7 +481,8 @@ class MakeRoomAdminRestServlet(RestServlet): if not admin_user_id: raise SynapseError( - 400, "No local admin user in room", + 400, + "No local admin user in room", ) pl_content = power_levels.content @@ -489,7 +492,8 @@ class MakeRoomAdminRestServlet(RestServlet): admin_user_id = create_event.sender if not self.is_mine_id(admin_user_id): raise SynapseError( - 400, "No local admin user in room", + 400, + "No local admin user in room", ) # Grant the user power equal to the room admin by attempting to send an @@ -499,7 +503,8 @@ class MakeRoomAdminRestServlet(RestServlet): new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id] fake_requester = create_requester( - admin_user_id, authenticated_entity=requester.authenticated_entity, + admin_user_id, + authenticated_entity=requester.authenticated_entity, ) try: @@ -605,3 +610,65 @@ class ForwardExtremitiesRestServlet(RestServlet): extremities = await self.store.get_forward_extremities_for_room(room_id) return 200, {"count": len(extremities), "results": extremities} + + +class RoomEventContextServlet(RestServlet): + """ + Provide the context for an event. + This API is designed to be used when system administrators wish to look at + an abuse report and understand what happened during and immediately prior + to this event. 
+ """ + + PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$") + + def __init__(self, hs): + super().__init__() + self.clock = hs.get_clock() + self.room_context_handler = hs.get_room_context_handler() + self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() + + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=False) + await assert_user_is_admin(self.auth, requester.user) + + limit = parse_integer(request, "limit", default=10) + + # picking the API shape for symmetry with /messages + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] + else: + event_filter = None + + results = await self.room_context_handler.get_event_context( + requester, + room_id, + event_id, + limit, + event_filter, + use_admin_priviledge=True, + ) + + if not results: + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + + time_now = self.clock.time_msec() + results["events_before"] = await self._event_serializer.serialize_events( + results["events_before"], time_now + ) + results["event"] = await self._event_serializer.serialize_event( + results["event"], time_now + ) + results["events_after"] = await self._event_serializer.serialize_events( + results["events_after"], time_now + ) + results["state"] = await self._event_serializer.serialize_events( + results["state"], time_now + ) + + return 200, results diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 68c3c64a0d..998a0ef671 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -579,7 +579,7 @@ class ResetPasswordRestServlet(RestServlet): } Returns: 200 OK with empty object if success otherwise an error. - """ + """ PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)") @@ -752,7 +752,7 @@ class PushersRestServlet(RestServlet): Returns: pushers: Dictionary containing pushers information. - total: Number of pushers in dictonary `pushers`. + total: Number of pushers in dictionary `pushers`. """ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$") diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0fb9419e58..6e2fbedd99 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -310,7 +310,9 @@ class LoginRestServlet(RestServlet): except jwt.PyJWTError as e: # A JWT error occurred, return some info back to the client. raise LoginError( - 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN, + 403, + "JWT validation failed: %s" % (str(e),), + errcode=Codes.FORBIDDEN, ) user = payload.get("sub", None) @@ -375,7 +377,9 @@ class SsoRedirectServlet(RestServlet): request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url, idp_id, + request, + client_redirect_url, + idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 23a529f8e3..94bfe2d1b0 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py
@@ -49,9 +49,7 @@ class PresenceStatusRestServlet(RestServlet): raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state( - state, self.clock.time_msec(), include_user_id=False - ) + state = format_user_presence_state(state, self.clock.time_msec()) return 200, state diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 85a66458c5..d77e20e135 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@ # limitations under the License. """ This module contains REST servlets to do with profile: /profile/<paths> """ +from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() + self.http_client = hs.get_simple_http_client() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -60,13 +62,34 @@ class ProfileDisplaynameRestServlet(RestServlet): new_name = content["displayname"] except Exception: raise SynapseError( - code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON, + code=400, + msg="Unable to parse name", + errcode=Codes.BAD_JSON, ) await self.profile_handler.set_displayname(user, requester, new_name, is_admin) + if self.hs.config.shadow_server: + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) + self.shadow_displayname(shadow_user.to_string(), content) + + return 200, {} + + def on_OPTIONS(self, request, user_id): return 200, {} + @defer.inlineCallbacks + def shadow_displayname(self, user_id, body): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.put_json( + "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, + ) + class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) @@ -75,6 +98,7 @@ class ProfileAvatarURLRestServlet(RestServlet): super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() + self.http_client = hs.get_simple_http_client() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -113,8 +137,27 @@ class ProfileAvatarURLRestServlet(RestServlet): user, requester, new_avatar_url, is_admin ) + if self.hs.config.shadow_server: + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) + self.shadow_avatar_url(shadow_user.to_string(), content) + + return 200, {} + + def on_OPTIONS(self, request, user_id): return 200, {} + @defer.inlineCallbacks + def shadow_avatar_url(self, user_id, body): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.put_json( + "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, + ) + class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
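The shadow_* helpers added to profile.py consume a `shadow_server` config block that is defined nowhere in this hunk. From the `.get()` calls, it must expose at least three keys; the shape below is an inference, not documented configuration:

```python
# Inferred shape of the `shadow_server` config block; the key names come from
# the .get() calls in this diff, all values here are placeholders.
shadow_server = {
    "hs": "shadow.example.com",              # server_name for the shadow UserID
    "hs_url": "https://shadow.example.com",  # base URL the mirrored PUTs target
    "as_token": "<application-service token, appended as ?access_token=...>",
}

# With that config, a displayname change for @alice:example.com is mirrored as:
#   PUT {hs_url}/_matrix/client/r0/profile/@alice:shadow.example.com/displayname
#       ?access_token={as_token}&user_id=@alice:shadow.example.com
```

Note that the servlets fire these mirrored requests without awaiting the resulting Deferred, so replication to the shadow server is best-effort (the "TODO: retries" comments acknowledge this).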
index 89823fcc39..0c148a213d 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py
@@ -159,7 +159,9 @@ class PushersRemoveRestServlet(RestServlet): self.notifier.on_new_replication_data() respond_with_html_bytes( - request, 200, PushersRemoveRestServlet.SUCCESS_HTML, + request, + 200, + PushersRemoveRestServlet.SUCCESS_HTML, ) return None diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f95627ee61..d2612fd067 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -15,10 +15,9 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/<paths> """ - import logging import re -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from urllib import parse as urlparse from synapse.api.constants import EventTypes, Membership @@ -37,6 +36,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_integer, parse_json_object_from_request, + parse_list_from_args, parse_string, ) from synapse.logging.opentracing import set_tag @@ -283,10 +283,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): if RoomID.is_valid(room_identifier): room_id = room_identifier try: - remote_room_hosts = [ - x.decode("ascii") for x in request.args[b"server_name"] - ] # type: Optional[List[str]] - except Exception: + remote_room_hosts = parse_list_from_args(request.args, "server_name") + except KeyError: remote_room_hosts = None elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler @@ -362,7 +360,9 @@ class PublicRoomListRestServlet(TransactionRestServlet): parse_and_validate_server_name(server) except ValueError: raise SynapseError( - 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, ) try: @@ -413,7 +413,9 @@ class PublicRoomListRestServlet(TransactionRestServlet): parse_and_validate_server_name(server) except ValueError: raise SynapseError( - 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM, + 400, + "Invalid server name: %s" % (server,), + Codes.INVALID_PARAM, ) try: @@ -650,7 +652,7 @@ class RoomEventContextServlet(RestServlet): event_filter = None results = await self.room_context_handler.get_event_context( - requester.user, room_id, event_id, limit, event_filter + requester, room_id, event_id, limit, event_filter ) if not results: @@ -739,7 +741,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): content["id_server"], requester, txn_id, - content.get("id_access_token"), + new_room=False, + id_access_token=content.get("id_access_token"), ) except ShadowBanError: # Pretend the request succeeded. diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
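`parse_list_from_args` is imported from synapse.http.servlet both here and in admin/rooms.py above, but its body is not shown in this diff. From the two call sites it must decode a repeated query parameter into a list of strings and raise KeyError when the parameter is absent. A sketch consistent with that behaviour; the signature is an assumption:

```python
from typing import Dict, List


def parse_list_from_args(args: Dict[bytes, List[bytes]], name: str) -> List[str]:
    """Decode every value of a repeated query parameter as ASCII.

    Raises KeyError if the parameter is missing, which the callers above
    catch in order to fall back to `remote_room_hosts = None`.
    """
    name_bytes = name.encode("ascii")
    if name_bytes not in args:
        raise KeyError(name)
    return [value.decode("ascii") for value in args[name_bytes]]
```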
index a84a2fb385..701280d05f 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ # limitations under the License. import logging import random +import re from http import HTTPStatus from typing import TYPE_CHECKING from urllib.parse import urlparse @@ -38,6 +39,7 @@ from synapse.http.servlet import ( ) from synapse.metrics import threepid_send_requests from synapse.push.mailer import Mailer +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import canonicalise_email, check_3pid_allowed @@ -166,6 +168,7 @@ class PasswordRestServlet(RestServlet): self.datastore = self.hs.get_datastore() self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() + self.http_client = hs.get_simple_http_client() @interactive_auth_handler async def on_POST(self, request): @@ -191,24 +194,32 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) - try: - params, session_id = await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "modify your account password", - ) - except InteractiveAuthIncompleteError as e: - # The user needs to provide more steps to complete auth, but - # they're not required to provide the password again. - # - # If a password is available now, hash the provided password and - # store it for later. - if new_password: - password_hash = await self.auth_handler.hash(new_password) - await self.auth_handler.set_session_data( - e.session_id, - UIAuthSessionDataConstants.PASSWORD_HASH, - password_hash, + # blindly trust ASes without UI-authing them + if requester.app_service: + params = body + else: + try: + ( + params, + session_id, + ) = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "modify your account password", ) - raise + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. 
+ if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise user_id = requester.user.to_string() else: requester = None @@ -278,8 +289,28 @@ class PasswordRestServlet(RestServlet): user_id, password_hash, logout_devices, requester ) + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + await self.shadow_password(params, shadow_user.to_string()) + return 200, {} + def on_OPTIONS(self, _): + return 200, {} + + async def shadow_password(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") @@ -312,7 +343,10 @@ class DeactivateAccountRestServlet(RestServlet): return 200, {} await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "deactivate your account", + requester, + request, + body, + "deactivate your account", ) result = await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), @@ -374,10 +408,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not (await check_3pid_allowed(self.hs, "email", email)): raise SynapseError( 403, - "Your email domain is not authorized on this server", + "Your email is not authorized on this server", Codes.THREEPID_DENIED, ) @@ -455,7 +489,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( 403, "Account phone numbers are not authorized on this server", @@ -629,7 +663,8 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = hs.get_datastore() + self.http_client = hs.get_simple_http_client() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) @@ -648,6 +683,29 @@ class ThreepidRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) + # skip validation if this is a shadow 3PID from an AS + if requester.app_service: + # XXX: ASes pass in a validated threepid directly to bypass the IS. + # This makes the API entirely change shape when we have an AS token; + # it really should be an entirely separate API - perhaps + # /account/3pid/replicate or something. 
+ threepid = body.get("threepid") + + await self.auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) + + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + + return 200, {} + threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") if threepid_creds is None: raise SynapseError( @@ -669,12 +727,35 @@ class ThreepidRestServlet(RestServlet): validation_session["address"], validation_session["validated_at"], ) + + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + threepid = { + "medium": validation_session["medium"], + "address": validation_session["address"], + "validated_at": validation_session["validated_at"], + } + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) + async def shadow_3pid(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") @@ -685,6 +766,7 @@ class ThreepidAddRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.http_client = hs.get_simple_http_client() @interactive_auth_handler async def on_POST(self, request): @@ -703,7 +785,10 @@ class ThreepidAddRestServlet(RestServlet): assert_valid_client_secret(client_secret) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "add a third-party identifier to your account", + requester, + request, + body, + "add a third-party identifier to your account", ) validation_session = await self.identity_handler.validate_threepid_session( @@ -716,12 +801,33 @@ class ThreepidAddRestServlet(RestServlet): validation_session["address"], validation_session["validated_at"], ) + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + threepid = { + "medium": validation_session["medium"], + "address": validation_session["address"], + "validated_at": validation_session["validated_at"], + } + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) + async def shadow_3pid(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") @@ -791,6 +897,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.http_client = hs.get_simple_http_client() async def on_POST(self, request): if not self.hs.config.enable_3pid_changes: 
@@ -815,6 +922,12 @@ class ThreepidDeleteRestServlet(RestServlet): logger.exception("Failed to remove threepid") raise SynapseError(500, "Failed to remove threepid") + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + await self.shadow_3pid_delete(body, shadow_user.to_string()) + if ret: id_server_unbind_result = "success" else: @@ -822,6 +935,74 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} + async def shadow_3pid_delete(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + await self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + + +class ThreepidLookupRestServlet(RestServlet): + PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")] + + def __init__(self, hs): + super(ThreepidLookupRestServlet, self).__init__() + self.auth = hs.get_auth() + self.identity_handler = hs.get_identity_handler() + + async def on_GET(self, request): + """Proxy a /_matrix/identity/api/v1/lookup request to an identity + server + """ + await self.auth.get_user_by_req(request) + + # Verify query parameters + query_params = request.args + assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"]) + + # Retrieve needed information from query parameters + medium = parse_string(request, "medium") + address = parse_string(request, "address") + id_server = parse_string(request, "id_server") + + # Proxy the request to the identity server. lookup_3pid handles checking + # if the lookup is allowed so we don't need to do it here. + ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address) + + return 200, ret + + +class ThreepidBulkLookupRestServlet(RestServlet): + PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")] + + def __init__(self, hs): + super(ThreepidBulkLookupRestServlet, self).__init__() + self.auth = hs.get_auth() + self.identity_handler = hs.get_identity_handler() + + async def on_POST(self, request): + """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity + server + """ + await self.auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["threepids", "id_server"]) + + # Proxy the request to the identity server. lookup_3pid handles checking + # if the lookup is allowed so we don't need to do it here. + ret = await self.identity_handler.proxy_bulk_lookup_3pid( + body["id_server"], body["threepids"] + ) + + return 200, ret + def assert_valid_next_link(hs: "HomeServer", next_link: str): """ @@ -888,4 +1069,6 @@ def register_servlets(hs, http_server): ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) + ThreepidLookupRestServlet(hs).register(http_server) + ThreepidBulkLookupRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
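The two new unstable servlets proxy identity-server lookups through the homeserver so that `lookup_3pid` can enforce the server's lookup policy. A client-side sketch of the bulk variant; the body shape (a list of [medium, address] pairs) is assumed from the proxied identity-service API rather than stated in this diff:

```python
import json
import urllib.request


def bulk_lookup_3pids(base_url, access_token, id_server, threepids):
    """POST /_matrix/client/unstable/account/3pid/bulk_lookup.

    `threepids` is assumed to follow the identity-server bulk_lookup shape,
    e.g. [["email", "alice@example.com"], ["msisdn", "447700900000"]].
    """
    body = json.dumps({"id_server": id_server, "threepids": threepids}).encode()
    req = urllib.request.Request(
        base_url + "/_matrix/client/unstable/account/3pid/bulk_lookup",
        data=body,
        headers={
            "Authorization": "Bearer " + access_token,
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```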
index 3f28c0bc3e..c9f13e4ac5 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import UserID from ._base import client_patterns @@ -38,6 +39,9 @@ class AccountDataServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.handler = hs.get_account_data_handler() + self.notifier = hs.get_notifier() + self._is_worker = hs.config.worker_app is not None + self._profile_handler = hs.get_profile_handler() async def on_PUT(self, request, user_id, account_data_type): requester = await self.auth.get_user_by_req(request) @@ -46,7 +50,15 @@ class AccountDataServlet(RestServlet): body = parse_json_object_from_request(request) - await self.handler.add_account_data_for_user(user_id, account_data_type, body) + if account_data_type == "im.vector.hide_profile": + user = UserID.from_string(user_id) + hide_profile = body.get("hide_profile") + await self._profile_handler.set_active([user], not hide_profile, True) + + max_id = await self.handler.add_account_data_for_user( + user_id, account_data_type, body + ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) return 200, {} diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
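For reference, the request that exercises the new `im.vector.hide_profile` branch is an ordinary account-data PUT; the endpoint shape follows the account_data servlet patterns, and the values below are illustrative:

```python
import json
import urllib.parse
import urllib.request


def set_profile_hidden(base_url, access_token, user_id, hidden):
    """PUT the im.vector.hide_profile account data handled above.

    The servlet inverts the flag: {"hide_profile": true} results in
    set_active([user], False, True), i.e. the profile goes inactive.
    """
    url = "%s/_matrix/client/r0/user/%s/account_data/im.vector.hide_profile" % (
        base_url,
        urllib.parse.quote(user_id),
    )
    req = urllib.request.Request(
        url,
        data=json.dumps({"hide_profile": hidden}).encode(),
        method="PUT",
        headers={
            "Authorization": "Bearer " + access_token,
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```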
index bd7f9ae203..40c5bd4d8c 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -37,24 +37,38 @@ class AccountValidityRenewServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.success_html = hs.config.account_validity.account_renewed_html_content - self.failure_html = hs.config.account_validity.invalid_token_html_content + self.account_renewed_template = ( + hs.config.account_validity_account_renewed_template + ) + self.account_previously_renewed_template = ( + hs.config.account_validity_account_previously_renewed_template + ) + self.invalid_token_template = hs.config.account_validity_invalid_token_template async def on_GET(self, request): if b"token" not in request.args: raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - token_valid = await self.account_activity_handler.renew_account( + ( + token_valid, + token_stale, + expiration_ts, + ) = await self.account_activity_handler.renew_account( renewal_token.decode("utf8") ) if token_valid: status_code = 200 - response = self.success_html + response = self.account_renewed_template.render(expiration_ts=expiration_ts) + elif token_stale: + status_code = 200 + response = self.account_previously_renewed_template.render( + expiration_ts=expiration_ts + ) else: status_code = 404 - response = self.failure_html + response = self.invalid_token_template.render(expiration_ts=expiration_ts) respond_with_html(request, status_code, response) @@ -72,10 +86,12 @@ class AccountValiditySendMailServlet(RestServlet): self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() self.auth = hs.get_auth() - self.account_validity = self.hs.config.account_validity + self.account_validity_renew_by_email_enabled = ( + self.hs.config.account_validity_renew_by_email_enabled + ) async def on_POST(self, request): - if not self.account_validity.renew_by_email_enabled: + if not self.account_validity_renew_by_email_enabled: raise AuthError( 403, "Account renewal via email is disabled on this server." ) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
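`renew_account` now returns a triple rather than a bool. A toy stand-in showing the contract the servlet destructures; the real handler lives in synapse/handlers/account_validity.py and also pushes the expiry forward, which this stub only hints at:

```python
from typing import Dict, Tuple


class StubAccountValidityHandler:
    """Toy model of the (token_valid, token_stale, expiration_ts) triple."""

    def __init__(self) -> None:
        # token -> (already_used, expiration_ts in ms)
        self._tokens = {}  # type: Dict[str, Tuple[bool, int]]

    async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]:
        if renewal_token not in self._tokens:
            # Unknown token: HTTP 404 + the invalid_token template.
            return False, False, 0
        used, expiration_ts = self._tokens[renewal_token]
        if used:
            # Token already spent: HTTP 200 + account_previously_renewed.html.
            return False, True, expiration_ts
        self._tokens[renewal_token] = (True, expiration_ts)
        # Fresh token: HTTP 200 + account_renewed.html.  (The real handler
        # also extends expiration_ts at this point.)
        return True, False, expiration_ts
```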
index 314e01dfe4..3d07aadd39 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,7 +83,10 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "remove device(s) from your account", + requester, + request, + body, + "remove device(s) from your account", ) await self.device_handler.delete_devices( @@ -129,7 +132,10 @@ class DeviceRestServlet(RestServlet): raise await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "remove a device from your account", + requester, + request, + body, + "remove a device from your account", ) await self.device_handler.delete_device(requester.user.to_string(), device_id) @@ -206,7 +212,9 @@ class DehydratedDeviceServlet(RestServlet): if "device_data" not in submission: raise errors.SynapseError( - 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM, + 400, + "device_data missing", + errcode=errors.Codes.MISSING_PARAM, ) elif not isinstance(submission["device_data"], dict): raise errors.SynapseError( @@ -259,11 +267,15 @@ class ClaimDehydratedDeviceServlet(RestServlet): if "device_id" not in submission: raise errors.SynapseError( - 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM, + 400, + "device_id missing", + errcode=errors.Codes.MISSING_PARAM, ) elif not isinstance(submission["device_id"], str): raise errors.SynapseError( - 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM, + 400, + "device_id must be a string", + errcode=errors.Codes.INVALID_PARAM, ) result = await self.device_handler.rehydrate_device( diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..d3434225cb 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,29 @@ import logging from functools import wraps - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request -from synapse.types import GroupID +from typing import TYPE_CHECKING, Optional, Tuple + +from twisted.web.http import Request + +from synapse.api.constants import ( + MAX_GROUP_CATEGORYID_LENGTH, + MAX_GROUP_ROLEID_LENGTH, + MAX_GROUPID_LENGTH, +) +from synapse.api.errors import Codes, SynapseError +from synapse.handlers.groups_local import GroupsLocalHandler +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.types import GroupID, JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -33,7 +49,7 @@ def _validate_group_id(f): """ @wraps(f) - def wrapper(self, request, group_id, *args, **kwargs): + def wrapper(self, request: Request, group_id: str, *args, **kwargs): if not GroupID.is_valid(group_id): raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) @@ -43,19 +59,18 @@ def _validate_group_id(f): class GroupServlet(RestServlet): - """Get the group profile - """ + """Get the group profile""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -66,11 +81,17 @@ class GroupServlet(RestServlet): return 200, group_description @_validate_group_id - async def on_POST(self, request, group_id): + async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert_params_in_dict( + content, ("name", "avatar_url", "short_description", "long_description") + ) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot create group profiles." 
await self.groups_handler.update_group_profile( group_id, requester_user_id, content ) @@ -79,19 +100,18 @@ class GroupServlet(RestServlet): class GroupSummaryServlet(RestServlet): - """Get the full group summary - """ + """Get the full group summary""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -116,18 +136,34 @@ class GroupSummaryRoomsCatServlet(RestServlet): "/rooms/(?P<room_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, category_id, room_id): + async def on_PUT( + self, request: Request, group_id: str, category_id: Optional[str], room_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + if category_id == "": + raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) + + if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." resp = await self.groups_handler.update_group_summary_room( group_id, requester_user_id, @@ -139,10 +175,15 @@ class GroupSummaryRoomsCatServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, category_id, room_id): + async def on_DELETE( + self, request: Request, group_id: str, category_id: str, room_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." 
resp = await self.groups_handler.delete_group_summary_room( group_id, requester_user_id, room_id=room_id, category_id=category_id ) @@ -151,21 +192,22 @@ class GroupSummaryRoomsCatServlet(RestServlet): class GroupCategoryServlet(RestServlet): - """Get/add/update/delete a group category - """ + """Get/add/update/delete a group category""" PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id, category_id): + async def on_GET( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -176,11 +218,27 @@ class GroupCategoryServlet(RestServlet): return 200, category @_validate_group_id - async def on_PUT(self, request, group_id, category_id): + async def on_PUT( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + if not category_id: + raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) + + if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: + raise SynapseError( + 400, + "category_id may not be longer than %s characters" + % (MAX_GROUP_CATEGORYID_LENGTH,), + Codes.INVALID_PARAM, + ) + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." resp = await self.groups_handler.update_group_category( group_id, requester_user_id, category_id=category_id, content=content ) @@ -188,10 +246,15 @@ class GroupCategoryServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, category_id): + async def on_DELETE( + self, request: Request, group_id: str, category_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group categories." 
resp = await self.groups_handler.delete_group_category( group_id, requester_user_id, category_id=category_id ) @@ -200,19 +263,18 @@ class GroupCategoryServlet(RestServlet): class GroupCategoriesServlet(RestServlet): - """Get all group categories - """ + """Get all group categories""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -224,19 +286,20 @@ class GroupCategoriesServlet(RestServlet): class GroupRoleServlet(RestServlet): - """Get/add/update/delete a group role - """ + """Get/add/update/delete a group role""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id, role_id): + async def on_GET( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -247,11 +310,27 @@ class GroupRoleServlet(RestServlet): return 200, category @_validate_group_id - async def on_PUT(self, request, group_id, role_id): + async def on_PUT( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + if not role_id: + raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) + + if len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group roles." resp = await self.groups_handler.update_group_role( group_id, requester_user_id, role_id=role_id, content=content ) @@ -259,10 +338,15 @@ class GroupRoleServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, role_id): + async def on_DELETE( + self, request: Request, group_id: str, role_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group roles." 
resp = await self.groups_handler.delete_group_role( group_id, requester_user_id, role_id=role_id ) @@ -271,19 +355,18 @@ class GroupRoleServlet(RestServlet): class GroupRolesServlet(RestServlet): - """Get all group roles - """ + """Get all group roles""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -308,18 +391,34 @@ class GroupSummaryUsersRoleServlet(RestServlet): "/users/(?P<user_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, role_id, user_id): + async def on_PUT( + self, request: Request, group_id: str, role_id: Optional[str], user_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + if role_id == "": + raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) + + if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH: + raise SynapseError( + 400, + "role_id may not be longer than %s characters" + % (MAX_GROUP_ROLEID_LENGTH,), + Codes.INVALID_PARAM, + ) + content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." resp = await self.groups_handler.update_group_summary_user( group_id, requester_user_id, @@ -331,10 +430,15 @@ class GroupSummaryUsersRoleServlet(RestServlet): return 200, resp @_validate_group_id - async def on_DELETE(self, request, group_id, role_id, user_id): + async def on_DELETE( + self, request: Request, group_id: str, role_id: str, user_id: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group summaries." 
resp = await self.groups_handler.delete_group_summary_user( group_id, requester_user_id, user_id=user_id, role_id=role_id ) @@ -343,19 +447,18 @@ class GroupSummaryUsersRoleServlet(RestServlet): class GroupRoomServlet(RestServlet): - """Get all rooms in a group - """ + """Get all rooms in a group""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -367,19 +470,18 @@ class GroupRoomServlet(RestServlet): class GroupUsersServlet(RestServlet): - """Get all users in a group - """ + """Get all users in a group""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -391,19 +493,18 @@ class GroupUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet): - """Get users invited to a group - """ + """Get users invited to a group""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_GET(self, request, group_id): + async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -415,23 +516,25 @@ class GroupInvitedUsersServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet): - """Set group join policy - """ + """Set group join policy""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify group join policy." 
result = await self.groups_handler.set_group_join_policy( group_id, requester_user_id, content ) @@ -440,19 +543,18 @@ class GroupSettingJoinPolicyServlet(RestServlet): class GroupCreateServlet(RestServlet): - """Create a group - """ + """Create a group""" PATTERNS = client_patterns("/create_group$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() self.server_name = hs.hostname - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -461,6 +563,19 @@ class GroupCreateServlet(RestServlet): localpart = content.pop("localpart") group_id = GroupID(localpart, self.server_name).to_string() + if not localpart: + raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) + + if len(group_id) > MAX_GROUPID_LENGTH: + raise SynapseError( + 400, + "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,), + Codes.INVALID_PARAM, + ) + + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot create groups." result = await self.groups_handler.create_group( group_id, requester_user_id, content ) @@ -469,25 +584,29 @@ class GroupCreateServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet): - """Add a room to the group - """ + """Add a room to the group""" PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, room_id): + async def on_PUT( + self, request: Request, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify rooms in a group." result = await self.groups_handler.add_room_to_group( group_id, requester_user_id, room_id, content ) @@ -495,10 +614,15 @@ class GroupAdminRoomsServlet(RestServlet): return 200, result @_validate_group_id - async def on_DELETE(self, request, group_id, room_id): + async def on_DELETE( + self, request: Request, group_id: str, room_id: str + ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify rooms in a group." 
result = await self.groups_handler.remove_room_from_group( group_id, requester_user_id, room_id ) @@ -507,26 +631,30 @@ class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet): - """Update the config of a room in a group - """ + """Update the config of a room in a group""" PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" "/config/(?P<config_key>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, room_id, config_key): + async def on_PUT( + self, request: Request, group_id: str, room_id: str, config_key: str + ): requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot modify room config in a group." result = await self.groups_handler.update_room_in_group( group_id, requester_user_id, room_id, config_key, content ) @@ -535,14 +663,13 @@ class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet): - """Invite a user to the group - """ + """Invite a user to the group""" PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() @@ -551,12 +678,15 @@ class GroupAdminUsersInviteServlet(RestServlet): self.is_mine_id = hs.is_mine_id @_validate_group_id - async def on_PUT(self, request, group_id, user_id): + async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) config = content.get("config", {}) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot invite users to a group." result = await self.groups_handler.invite( group_id, user_id, requester_user_id, config ) @@ -565,25 +695,27 @@ class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet): - """Kick a user from the group - """ + """Kick a user from the group""" PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" ) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id, user_id): + async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot kick users from a group." 
result = await self.groups_handler.remove_user_from_group( group_id, user_id, requester_user_id, content ) @@ -592,23 +724,25 @@ class GroupAdminUsersKickServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet): - """Leave a joined group - """ + """Leave a joined group""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot leave a group for a user." result = await self.groups_handler.remove_user_from_group( group_id, requester_user_id, requester_user_id, content ) @@ -617,23 +751,25 @@ class GroupSelfLeaveServlet(RestServlet): class GroupSelfJoinServlet(RestServlet): - """Attempt to join a group, or knock - """ + """Attempt to join a group, or knock""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot join a user to a group." result = await self.groups_handler.join_group( group_id, requester_user_id, content ) @@ -642,23 +778,25 @@ class GroupSelfJoinServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet): - """Accept a group invite - """ + """Accept a group invite""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() content = parse_json_object_from_request(request) + assert isinstance( + self.groups_handler, GroupsLocalHandler + ), "Workers cannot accept an invite to a group." 
result = await self.groups_handler.accept_invite( group_id, requester_user_id, content ) @@ -667,19 +805,18 @@ class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet): - """Update whether we publicise a users membership of a group - """ + """Update whether we publicise a user's membership of a group""" PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @_validate_group_id - async def on_PUT(self, request, group_id): + async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) requester_user_id = requester.user.to_string() @@ -691,19 +828,18 @@ class GroupSelfUpdatePublicityServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet): - """Get the list of groups a user is advertising - """ + """Get the list of groups a user is advertising""" PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request, user_id): + async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) result = await self.groups_handler.get_publicised_groups_for_user(user_id) @@ -712,19 +848,18 @@ class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet): - """Get the list of groups a user is advertising - """ + """Get the lists of groups a set of users are advertising""" PATTERNS = client_patterns("/publicised_groups$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() self.groups_handler = hs.get_groups_local_handler() - async def on_POST(self, request): + async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) @@ -736,18 +871,17 @@ class PublicisedGroupsForUsersServlet(RestServlet): class GroupsForUserServlet(RestServlet): - """Get all groups the logged in user is joined to - """ + """Get all groups the logged-in user is joined to""" PATTERNS = client_patterns("/joined_groups$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() - async def on_GET(self, request): + async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() @@ -756,7 +890,7 @@ class GroupsForUserServlet(RestServlet): return 200, result -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): GroupServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server) GroupInvitedUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
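The group, category and role endpoints above all repeat the same two-step validation: reject an empty identifier, then enforce the relevant MAX_*_LENGTH constant. A hypothetical helper (not part of this changeset) capturing the shared pattern:

from synapse.api.errors import Codes, SynapseError

def _check_group_sub_id(value: str, max_length: int, name: str) -> None:
    """Reject empty or over-long category/role identifiers with a 400."""
    if not value:
        raise SynapseError(400, "%s cannot be empty" % (name,), Codes.INVALID_PARAM)
    if len(value) > max_length:
        raise SynapseError(
            400,
            "%s may not be longer than %s characters" % (name, max_length),
            Codes.INVALID_PARAM,
        )

# e.g. _check_group_sub_id(category_id, MAX_GROUP_CATEGORYID_LENGTH, "category_id")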
index a6134ead8a..f092e5b3a2 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,7 +271,10 @@ class SigningKeyUploadServlet(RestServlet): body = parse_json_object_from_request(request) await self.auth_handler.validate_user_via_ui_auth( - requester, request, body, "add a device signing key to your account", + requester, + request, + body, + "add a device signing key to your account", ) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) diff --git a/synapse/rest/client/v2_alpha/knock.py b/synapse/rest/client/v2_alpha/knock.py new file mode 100644
index 0000000000..8439da447e --- /dev/null +++ b/synapse/rest/client/v2_alpha/knock.py
@@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Sorunome +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Optional, Tuple + +from twisted.web.server import Request + +from synapse.api.constants import Membership +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_json_object_from_request, + parse_list_from_args, +) +from synapse.logging.opentracing import set_tag +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.types import JsonDict, RoomAlias, RoomID + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class KnockRoomAliasServlet(RestServlet): + """ + POST /xyz.amorgan.knock/{roomIdOrAlias} + """ + + PATTERNS = client_patterns( + "/xyz.amorgan.knock/(?P<room_identifier>[^/]*)", releases=() + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.txns = HttpTransactionCache(hs) + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + async def on_POST( + self, request: Request, room_identifier: str, txn_id: Optional[str] = None, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + content = parse_json_object_from_request(request) + event_content = None + if "reason" in content: + event_content = {"reason": content["reason"]} + + if RoomID.is_valid(room_identifier): + room_id = room_identifier + try: + remote_room_hosts = parse_list_from_args(request.args, "server_name") + except KeyError: + remote_room_hosts = None + elif RoomAlias.is_valid(room_identifier): + handler = self.room_member_handler + room_alias = RoomAlias.from_string(room_identifier) + room_id_obj, remote_room_hosts = await handler.lookup_room_alias(room_alias) + room_id = room_id_obj.to_string() + else: + raise SynapseError( + 400, "%s was not a legal room ID or room alias" % (room_identifier,) + ) + + await self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action=Membership.KNOCK, + txn_id=txn_id, + third_party_signed=None, + remote_room_hosts=remote_room_hosts, + content=event_content, + ) + + return 200, {"room_id": room_id} + + def on_PUT(self, request: Request, room_identifier: str, txn_id: str): + set_tag("txn_id", txn_id) + + return self.txns.fetch_or_execute_request( + request, self.on_POST, request, room_identifier, txn_id + ) + + +def register_servlets(hs, http_server): + KnockRoomAliasServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
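For context, a knock against the unstable endpoint added above might look like the following (the homeserver URL and access token are placeholders; room aliases are resolved server-side via lookup_room_alias):

import requests

resp = requests.post(
    "https://example.org/_matrix/client/unstable"
    "/xyz.amorgan.knock/%23somewhere%3Aexample.org",  # room alias, URL-encoded
    params={"access_token": "..."},
    json={"reason": "Hoping to join the discussion"},  # becomes the knock event content
)
print(resp.json())  # => {"room_id": "!abc123:example.org"}

# When a plain room ID is supplied instead of an alias, candidate servers
# can be passed explicitly: ?server_name=example.org&server_name=other.example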
index 10e1891174..a7aea914e9 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -# Copyright 2015 - 2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +18,7 @@ import hmac import logging import random +import re from typing import List, Union import synapse @@ -119,10 +121,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not (await check_3pid_allowed(self.hs, "email", body["email"])): raise SynapseError( 403, - "Your email domain is not authorized to register on this server", + "Your email is not authorized to register on this server", Codes.THREEPID_DENIED, ) @@ -193,6 +195,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): body, ["client_secret", "country", "phone_number", "send_attempt"] ) client_secret = body["client_secret"] + assert_valid_client_secret(client_secret) country = body["country"] phone_number = body["phone_number"] send_attempt = body["send_attempt"] @@ -200,7 +203,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + assert_valid_client_secret(body["client_secret"]) + + if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( 403, "Phone numbers are not authorized to register on this server", @@ -293,6 +298,7 @@ class RegistrationSubmitTokenServlet(RestServlet): sid = parse_string(request, "sid", required=True) client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) token = parse_string(request, "token", required=True) # Attempt to validate a 3PID session @@ -360,15 +366,9 @@ class UsernameAvailabilityRestServlet(RestServlet): 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) - ip = request.getClientIP() - with self.ratelimiter.ratelimit(ip) as wait_deferred: - await wait_deferred - - username = parse_string(request, "username", required=True) - - await self.registration_handler.check_username(username) - - return 200, {"available": True} + # We are not interested in logging in via a username in this deployment. + # Simply allow anything here as it won't be used later. + return 200, {"available": True} class RegisterRestServlet(RestServlet): @@ -418,18 +418,27 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) - # Pull out the provided username and do basic sanity checks early since - # the auth layer will store these in sessions. + # We don't care about usernames for this deployment. In fact, the act + # of checking whether they exist already can leak metadata about + # which users are already registered. + # + # Usernames are already derived via the provided email. + # So, if they're not necessary, just ignore them. 
+ # + # (we do still allow appservices to set them below) desired_username = None - if "username" in body: - if not isinstance(body["username"], str) or len(body["username"]) > 512: - raise SynapseError(400, "Invalid username") - desired_username = body["username"] + + desired_display_name = body.get("display_name") appservice = None if self.auth.has_access_token(request): appservice = self.auth.get_appservice_by_req(request) + # We need to retrieve the password early in order to pass it to + # application service registration + # This is specific to shadow server registration of users via an AS + password = body.pop("password", None) + # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -438,7 +447,7 @@ class RegisterRestServlet(RestServlet): # Set the desired user according to the AS API (which uses the # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. - desired_username = body.get("user", desired_username) + desired_username = body.get("user", body.get("username")) # XXX we should check that desired_username is valid. Currently # we give appservices carte blanche for any insanity in mxids, @@ -451,7 +460,7 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Desired Username is missing or not a string") result = await self._do_appservice_registration( - desired_username, access_token, body + desired_username, password, desired_display_name, access_token, body ) return 200, result @@ -460,16 +469,6 @@ class RegisterRestServlet(RestServlet): if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN) - # For regular registration, convert the provided username to lowercase - # before attempting to register it. This should mean that people who try - # to register with upper-case in their usernames don't get a nasty surprise. - # - # Note that we treat usernames case-insensitively in login, so they are - # free to carry on imagining that their username is CrAzYh4cKeR if that - # keeps them happy. - if desired_username is not None: - desired_username = desired_username.lower() - # Check if this account is upgrading from a guest account. guest_access_token = body.get("guest_access_token", None) @@ -478,7 +477,6 @@ class RegisterRestServlet(RestServlet): # Note that we remove the password from the body since the auth layer # will store the body in the session and we don't want a plaintext # password store there. - password = body.pop("password", None) if password is not None: if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") @@ -508,19 +506,14 @@ class RegisterRestServlet(RestServlet): session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None ) - # Ensure that the username is valid. - if desired_username is not None: - await self.registration_handler.check_username( - desired_username, - guest_access_token=guest_access_token, - assigned_user_id=registered_user_id, - ) - # Check if the user-interactive authentication flows are complete, if # not this will raise a user-interactive auth error. try: auth_result, params, session_id = await self.auth_handler.check_ui_auth( - self._registration_flows, request, body, "register a new account", + self._registration_flows, + request, + body, + "register a new account", ) except InteractiveAuthIncompleteError as e: # The user needs to provide more steps to complete auth. 
@@ -552,7 +545,7 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - if not check_3pid_allowed(self.hs, medium, address): + if not (await check_3pid_allowed(self.hs, medium, address)): raise SynapseError( 403, "Third party identifiers (email/phone numbers)" @@ -560,6 +553,80 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + existingUid = await self.store.get_user_id_by_threepid( + medium, address + ) + + if existingUid is not None: + raise SynapseError( + 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE + ) + + if self.hs.config.register_mxid_from_3pid: + # override the desired_username based on the 3PID if any. + # reset it first to avoid folks picking their own username. + desired_username = None + + # we should have an auth_result at this point if we're going to progress + # to register the user (i.e. we haven't picked up a registered_user_id + # from our session store), in which case get ready and gen the + # desired_username + if auth_result: + if ( + self.hs.config.register_mxid_from_3pid == "email" + and LoginType.EMAIL_IDENTITY in auth_result + ): + address = auth_result[LoginType.EMAIL_IDENTITY]["address"] + desired_username = synapse.types.strip_invalid_mxid_characters( + address.replace("@", "-").lower() + ) + + # find a unique mxid for the account, suffixing numbers + # if needed + while True: + try: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + # if we got this far we passed the check. + break + except SynapseError as e: + if e.errcode == Codes.USER_IN_USE: + m = re.match(r"^(.*?)(\d+)$", desired_username) + if m: + desired_username = m.group(1) + str( + int(m.group(2)) + 1 + ) + else: + desired_username += "1" + else: + # something else went wrong. 
+ break + + if self.hs.config.register_just_use_email_for_display_name: + desired_display_name = address + else: + # Custom mapping between email address and display name + desired_display_name = _map_email_to_displayname(address) + elif ( + self.hs.config.register_mxid_from_3pid == "msisdn" + and LoginType.MSISDN in auth_result + ): + desired_username = auth_result[LoginType.MSISDN]["address"] + else: + raise SynapseError( + 400, "Cannot derive mxid from 3pid; no recognised 3pid" + ) + + if desired_username is not None: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", registered_user_id @@ -574,7 +641,12 @@ class RegisterRestServlet(RestServlet): if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) - desired_username = params.get("username", None) + if not self.hs.config.register_mxid_from_3pid: + desired_username = params.get("username", None) + else: + # we keep the original desired_username derived from the 3pid above + pass + guest_access_token = params.get("guest_access_token", None) if desired_username is not None: @@ -623,6 +695,7 @@ class RegisterRestServlet(RestServlet): localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, + default_display_name=desired_display_name, threepid=threepid, address=client_addr, user_agent_ips=entries, @@ -635,6 +708,14 @@ class RegisterRestServlet(RestServlet): ): await self.store.upsert_monthly_active_user(registered_user_id) + if self.hs.config.shadow_server: + await self.registration_handler.shadow_register( + localpart=desired_username, + display_name=desired_display_name, + auth_result=auth_result, + params=params, + ) + # Remember that the user account has been registered (and the user # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( @@ -658,14 +739,42 @@ class RegisterRestServlet(RestServlet): return 200, return_dict - async def _do_appservice_registration(self, username, as_token, body): + async def _do_appservice_registration( + self, username, password, display_name, as_token, body + ): + # FIXME: appservice_register() is horribly duplicated with register() + # and they should probably just be combined together with a config flag. 
+ + if password: + # Hash the password + # + # In mainline, hashing of the password was moved further on in the registration + # flow, but we need it here for the AS use-case of shadow servers + password = await self.auth_handler.hash(password) user_id = await self.registration_handler.appservice_register( - username, as_token + username, as_token, password, display_name ) - return await self._create_registration_details( - user_id, body, is_appservice_ghost=True, + result = await self._create_registration_details( + user_id, + body, + is_appservice_ghost=True, ) + auth_result = body.get("auth_result") + if auth_result and LoginType.EMAIL_IDENTITY in auth_result: + threepid = auth_result[LoginType.EMAIL_IDENTITY] + await self.registration_handler.register_email_threepid( + user_id, threepid, result["access_token"] + ) + + if auth_result and LoginType.MSISDN in auth_result: + threepid = auth_result[LoginType.MSISDN] + await self.registration_handler.register_msisdn_threepid( + user_id, threepid, result["access_token"] + ) + + return result + async def _create_registration_details( self, user_id, params, is_appservice_ghost=False ): @@ -721,6 +830,60 @@ class RegisterRestServlet(RestServlet): ) +def cap(name): + """Capitalise parts of a name containing different words, including those + separated by hyphens. + For example, 'john-doe smith' becomes 'John-Doe Smith'. + + Args: + name (str): The name to parse + """ + if not name: + return name + + # Split the name by whitespace then hyphens, capitalizing each part then + # joining it back together. + capitalized_name = " ".join( + "-".join(part.capitalize() for part in space_part.split("-")) + for space_part in name.split() + ) + return capitalized_name + + +def _map_email_to_displayname(address): + """Custom mapping from an email address to a user displayname + + Args: + address (str): The email address to process + Returns: + str: The new displayname + """ + # Replace all . in the address with spaces, then split it into the + # parts before and after the @ + parts = address.replace(".", " ").split("@") + + # Figure out which org this email address belongs to + org_parts = parts[1].split(" ") + + # If this is a ...matrix.org email, mark them as an Admin + if org_parts[-2] == "matrix" and org_parts[-1] == "org": + org = "Tchap Admin" + + # If this is a ...gouv.fr address, set the org to whatever is before + # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their + # org as "gouv" + elif org_parts[-2] == "gouv" and org_parts[-1] == "fr": + org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2] + + # Otherwise, mark their org as the email's second-level domain name + else: + org = org_parts[-2] + + desired_display_name = cap(parts[0]) + " [" + cap(org) + "]" + + return desired_display_name + + def _calculate_registration_flows( # technically `config` has to provide *all* of these interfaces, not just one config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
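Tracing the two helpers above on some made-up addresses: the gouv.fr branch takes whatever precedes "gouv", the matrix.org branch is hardcoded to "Tchap Admin", and everything else falls back to the second-level domain.

# Worked examples for cap() and _map_email_to_displayname():
assert cap("jean dupont") == "Jean Dupont"
assert cap("marie-claire") == "Marie-Claire"
assert _map_email_to_displayname("jean.dupont@interieur.gouv.fr") == "Jean Dupont [Interieur]"
assert _map_email_to_displayname("admin.user@matrix.org") == "Admin User [Tchap Admin]"
assert _map_email_to_displayname("jane@example.com") == "Jane [Example]"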
index 18c75738f8..fe765da23c 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py
@@ -244,7 +244,9 @@ class RelationAggregationPaginationServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True, + room_id, + requester.user.to_string(), + allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to @@ -322,7 +324,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.check_user_in_room_or_world_readable( - room_id, requester.user.to_string(), allow_departed_users=True, + room_id, + requester.user.to_string(), + allow_departed_users=True, ) # This checks that a) the event exists and b) the user is allowed to diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index bf030e0ff4..147920767f 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) class RoomUpgradeRestServlet(RestServlet): - """Handler for room uprade requests. + """Handler for room upgrade requests. Handles requests of the form: diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8e52e4cca4..582c999abd 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py
@@ -12,11 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import itertools import logging +from typing import Any, Callable, Dict, List -from synapse.api.constants import PresenceState +from synapse.api.constants import Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( @@ -24,7 +24,7 @@ from synapse.events.utils import ( format_event_raw, ) from synapse.handlers.presence import format_user_presence_state -from synapse.handlers.sync import SyncConfig +from synapse.handlers.sync import KnockedSyncResult, SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken from synapse.util import json_decoder @@ -213,6 +213,10 @@ class SyncRestServlet(RestServlet): sync_result.invited, time_now, access_token_id, event_formatter ) + knocked = await self.encode_knocked( + sync_result.knocked, time_now, access_token_id, event_formatter + ) + archived = await self.encode_archived( sync_result.archived, time_now, @@ -230,11 +234,16 @@ class SyncRestServlet(RestServlet): "left": list(sync_result.device_lists.left), }, "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now), - "rooms": {"join": joined, "invite": invited, "leave": archived}, + "rooms": { + Membership.JOIN: joined, + Membership.INVITE: invited, + Membership.KNOCK: knocked, + Membership.LEAVE: archived, + }, "groups": { - "join": sync_result.groups.join, - "invite": sync_result.groups.invite, - "leave": sync_result.groups.leave, + Membership.JOIN: sync_result.groups.join, + Membership.INVITE: sync_result.groups.invite, + Membership.LEAVE: sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types, @@ -296,7 +305,7 @@ class SyncRestServlet(RestServlet): Args: rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of - sync results for rooms this user is joined to + sync results for rooms this user is invited to time_now(int): current time - used as a baseline for age calculations token_id(int): ID of the user's auth token - used for namespacing @@ -315,7 +324,7 @@ class SyncRestServlet(RestServlet): time_now, token_id=token_id, event_format=event_formatter, - is_invite=True, + include_stripped_room_state=True, ) unsigned = dict(invite.get("unsigned", {})) invite["unsigned"] = unsigned @@ -325,6 +334,60 @@ class SyncRestServlet(RestServlet): return invited + async def encode_knocked( + self, + rooms: List[KnockedSyncResult], + time_now: int, + token_id: int, + event_formatter: Callable[[Dict], Dict], + ) -> Dict[str, Dict[str, Any]]: + """ + Encode the rooms we've knocked on in a sync result. + + Args: + rooms: list of sync results for rooms this user is knocking on + time_now: current time - used as a baseline for age calculations + token_id: ID of the user's auth token - used for namespacing of transaction IDs + event_formatter: function to convert from federation format to client format + + Returns: + The list of rooms the user has knocked on, in our response format. 
+ """ + knocked = {} + for room in rooms: + knock = await self._event_serializer.serialize_event( + room.knock, + time_now, + token_id=token_id, + event_format=event_formatter, + include_stripped_room_state=True, + ) + + # Extract the `unsigned` key from the knock event. + # This is where we (cheekily) store the knock state events + unsigned = knock.setdefault("unsigned", {}) + + # Duplicate the dictionary in order to avoid modifying the original + unsigned = dict(unsigned) + + # Extract the stripped room state from the unsigned dict + # This is for clients to get a little bit of information about + # the room they've knocked on, without revealing any sensitive information + knocked_state = list(unsigned.pop("knock_room_state", [])) + + # Append the actual knock membership event itself as well. This provides + # the client with: + # + # * A knock state event that they can use for easier internal tracking + # * The rough timestamp of when the knock occurred contained within the event + knocked_state.append(knock) + + # Build the `knock_state` dictionary, which will contain the state of the + # room that the client has knocked on + knocked[room.room_id] = {"knock_state": {"events": knocked_state}} + + return knocked + async def encode_archived( self, rooms, time_now, token_id, event_fields, event_formatter ): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index ad598cefe0..eeddfa31f8 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@ # limitations under the License. import logging +from typing import Dict -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from signedjson.sign import sign_json + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.types import UserID from ._base import client_patterns @@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() async def on_POST(self, request): """Searches for users in directory @@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet): body = parse_json_object_from_request(request) + if self.hs.config.user_directory_defer_to_id_server: + signed_body = sign_json( + body, self.hs.hostname, self.hs.config.signing_key[0] + ) + url = "%s/_matrix/identity/api/v1/user_directory/search" % ( + self.hs.config.user_directory_defer_to_id_server, + ) + resp = await self.http_client.post_json_get_json(url, signed_body) + return 200, resp + limit = body.get("limit", 10) limit = min(limit, 50) @@ -76,5 +95,124 @@ class UserDirectorySearchRestServlet(RestServlet): return 200, results +class SingleUserInfoServlet(RestServlet): + """ + Deprecated and replaced by `/users/info` + + GET /user/{user_id}/info HTTP/1.1 + """ + + PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$") + + def __init__(self, hs): + super(SingleUserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + registry = hs.get_federation_registry() + + if not registry.query_handlers.get("user_info"): + registry.register_query_handler("user_info", self._on_federation_query) + + async def on_GET(self, request, user_id): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + # Attempt to make a federation request to the server that owns this user + args = {"user_id": user_id} + res = await self.transport_layer.make_query( + user.domain, "user_info", args, retry_on_dns_fail=True + ) + return 200, res + + user_id_to_info = await self.store.get_info_for_users([user_id]) + return 200, user_id_to_info[user_id] + + async def _on_federation_query(self, args): + """Called when a request for user information appears over federation + + Args: + args (dict): Dictionary of query arguments provided by the request + + Returns: + Deferred[dict]: Deactivation and expiration information for a given user + """ + user_id = args.get("user_id") + if not user_id: + raise SynapseError(400, "user_id not provided") + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + raise SynapseError(400, "User is not hosted on this homeserver") + + user_ids_to_info_dict = await self.store.get_info_for_users([user_id]) + return user_ids_to_info_dict[user_id] + + +class UserInfoServlet(RestServlet): + """Bulk version of `/user/{user_id}/info` endpoint + + GET /users/info HTTP/1.1 + + Returns a dictionary of user_id to info dictionary. 
Supports remote users + """ + + PATTERNS = client_patterns("/users/info$", unstable=True, releases=()) + + def __init__(self, hs): + super(UserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + + async def on_POST(self, request): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + # Extract the user_ids from the request + body = parse_json_object_from_request(request) + assert_params_in_dict(body, required=["user_ids"]) + + user_ids = body["user_ids"] + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + # Separate local and remote users + local_user_ids = set() + remote_server_to_user_ids = {} # type: Dict[str, set] + for user_id in user_ids: + user = UserID.from_string(user_id) + + if self.hs.is_mine(user): + local_user_ids.add(user_id) + else: + remote_server_to_user_ids.setdefault(user.domain, set()) + remote_server_to_user_ids[user.domain].add(user_id) + + # Retrieve info of all local users + user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids) + + # Request info of each remote user from their remote homeserver + for server_name, user_id_set in remote_server_to_user_ids.items(): + # Make a request to the given server about their own users + res = await self.transport_layer.get_info_of_users( + server_name, list(user_id_set) + ) + + user_id_to_info_dict.update(res) + + return 200, user_id_to_info_dict + + def register_servlets(hs, http_server): UserDirectorySearchRestServlet(hs).register(http_server) + SingleUserInfoServlet(hs).register(http_server) + UserInfoServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
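A note on the identity-server deferral in UserDirectorySearchRestServlet above: the search body is signed with the homeserver's own key before being proxied, so the identity server can tell which homeserver asked. Below is a minimal standalone sketch of the same signedjson calls; the server name "example.com" and key version "a_key" are invented for the demo.

    from signedjson.key import generate_signing_key, get_verify_key
    from signedjson.sign import sign_json, verify_signed_json

    signing_key = generate_signing_key("a_key")  # ed25519 key, version "a_key"
    body = {"search_term": "alice", "limit": 10}

    # Attach a signature under our server name, as the servlet does with
    # hs.config.signing_key[0] before deferring to the identity server.
    signed = sign_json(body, "example.com", signing_key)
    print(list(signed["signatures"]["example.com"]))  # ['ed25519:a_key']

    # The receiving side verifies with the matching public key; this raises
    # SignatureVerifyException if the signature does not check out.
    verify_signed_json(signed, "example.com", get_verify_key(signing_key))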
index d24a199318..c9b9e7f5ff 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -72,9 +72,12 @@ class VersionsRestServlet(RestServlet): # MSC2326. "org.matrix.label_based_filtering": True, # Implements support for cross signing as described in MSC1756 - "org.matrix.e2e_cross_signing": True, + # "org.matrix.e2e_cross_signing": True, # Implements additional endpoints as described in MSC2432 "org.matrix.msc2432": True, + # Tchap does not currently assume this rule for r0.5.0 + # XXX: Remove this when it does + "m.lazy_load_members": True, # Implements additional endpoints as described in MSC2666 "uk.half-shot.msc2666": True, # Whether new rooms will be set to encrypted or not (based on presets). diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index f71a03a12d..90bbeca679 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py
@@ -137,7 +137,7 @@ def add_file_headers( # section 3.6 [2] to be a `token` or a `quoted-string`, where a `token` # is (essentially) a single US-ASCII word, and a `quoted-string` is a # US-ASCII string surrounded by double-quotes, using backslash as an - # escape charater. Note that %-encoding is *not* permitted. + # escape character. Note that %-encoding is *not* permitted. # # `filename*` is defined to be an `ext-value`, which is defined in # RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`, diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
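The token/quoted-string rules documented in add_file_headers above are easiest to see in a tiny example. This is an illustrative helper, not Synapse's actual implementation: a plain ASCII name can travel in `filename` as a bare token, while anything else goes in the RFC 5987 `filename*` form, where %-encoding is precisely the escape mechanism.

    import re
    from urllib.parse import quote

    def content_disposition(upload_name: str) -> str:
        # A bare ASCII word is a valid `token` and needs no quoting.
        if re.fullmatch(r"[A-Za-z0-9._-]+", upload_name):
            return "inline; filename=%s" % (upload_name,)
        # Otherwise use the RFC 5987 `filename*` ext-value form, where
        # %-encoding *is* the escaping mechanism.
        return "inline; filename*=utf-8''%s" % (quote(upload_name),)

    print(content_disposition("report.pdf"))
    # inline; filename=report.pdf
    print(content_disposition("r\u00e9sum\u00e9 2021.pdf"))
    # inline; filename*=utf-8''r%C3%A9sum%C3%A9%202021.pdf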
index 3ed219ae43..48f4433155 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py
@@ -51,7 +51,8 @@ class DownloadResource(DirectServeJsonResource): b" object-src 'self';", ) request.setHeader( - b"Referrer-Policy", b"no-referrer", + b"Referrer-Policy", + b"no-referrer", ) server_name, media_id, name = parse_media_id(request) if server_name == self.server_name: diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 4c9946a616..a0162d4255 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -184,7 +184,7 @@ class MediaRepository: async def get_local_media( self, request: Request, media_id: str, name: Optional[str] ) -> None: - """Responds to reqests for local media, if exists, or returns 404. + """Responds to requests for local media, if exists, or returns 404. Args: request: The incoming request. @@ -306,7 +306,7 @@ class MediaRepository: media_info = await self.store.get_cached_remote_media(server_name, media_id) # file_id is the ID we use to track the file locally. If we've already - # seen the file then reuse the existing ID, otherwise genereate a new + # seen the file then reuse the existing ID, otherwise generate a new # one. # If we have an entry in the DB, try and look for it @@ -325,7 +325,10 @@ class MediaRepository: # Failed to find the file anywhere, lets download it. try: - media_info = await self._download_remote_file(server_name, media_id,) + media_info = await self._download_remote_file( + server_name, + media_id, + ) except SynapseError: raise except Exception as e: @@ -351,7 +354,11 @@ class MediaRepository: responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name: str, media_id: str,) -> dict: + async def _download_remote_file( + self, + server_name: str, + media_id: str, + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. @@ -773,7 +780,11 @@ class MediaRepository: ) except Exception as e: thumbnail_exists = await self.store.get_remote_media_thumbnail( - server_name, media_id, t_width, t_height, t_type, + server_name, + media_id, + t_width, + t_height, + t_type, ) if not thumbnail_exists: raise e @@ -832,7 +843,10 @@ class MediaRepository: return await self._remove_local_media_from_disk([media_id]) async def delete_old_local_media( - self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True, + self, + before_ts: int, + size_gt: int = 0, + keep_profiles: bool = True, ) -> Tuple[List[str], int]: """ Delete local or remote media from this server by size and timestamp. Removes @@ -849,7 +863,9 @@ class MediaRepository: A tuple of (list of deleted media IDs, total deleted media IDs). """ old_media = await self.store.get_local_media_before( - before_ts, size_gt, keep_profiles, + before_ts, + size_gt, + keep_profiles, ) return await self._remove_local_media_from_disk(old_media) @@ -927,10 +943,10 @@ class MediaRepositoryResource(Resource): <thumbnail> - The thumbnail methods are "crop" and "scale". "scale" trys to return an + The thumbnail methods are "crop" and "scale". "scale" tries to return an image where either the width or the height is smaller than the requested size. The client should then scale and letterbox the image if it needs to - fit within a given rectangle. "crop" trys to return an image where the + fit within a given rectangle. "crop" tries to return an image where the width and height are close to the requested size and the aspect matches the requested size. The client should scale the image if it needs to fit within a given rectangle. diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
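The "scale" method described in the docstring fixes above is plain aspect-ratio arithmetic; a hedged sketch of that geometry follows (not the repository's actual thumbnailer):

    def scale_within(width: int, height: int, max_w: int, max_h: int):
        # Shrink (never enlarge) so the image fits inside max_w x max_h while
        # keeping its aspect ratio; the client letterboxes the remainder.
        ratio = min(max_w / width, max_h / height, 1)
        return max(1, round(width * ratio)), max(1, round(height * ratio))

    print(scale_within(1920, 1080, 320, 240))  # (320, 180)
    print(scale_within(100, 100, 320, 240))    # (100, 100): no upscaling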
index 89cdd605aa..1057e638be 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py
@@ -16,13 +16,17 @@ import contextlib import logging import os import shutil -from typing import IO, TYPE_CHECKING, Any, Optional, Sequence +from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence + +import attr from twisted.internet.defer import Deferred from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender +from synapse.api.errors import NotFoundError from synapse.logging.context import defer_to_thread, make_deferred_yieldable +from synapse.util import Clock from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder @@ -58,6 +62,8 @@ class MediaStorage: self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers + self.spam_checker = hs.get_spam_checker() + self.clock = hs.get_clock() async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other @@ -79,8 +85,7 @@ class MediaStorage: return fname async def write_to_file(self, source: IO, output: IO): - """Asynchronously write the `source` to `output`. - """ + """Asynchronously write the `source` to `output`.""" await defer_to_thread(self.reactor, _write_file_synchronously, source, output) @contextlib.contextmanager @@ -127,18 +132,29 @@ class MediaStorage: f.flush() f.close() + spam = await self.spam_checker.check_media_file_for_spam( + ReadableFileWrapper(self.clock, fname), file_info + ) + if spam: + logger.info("Blocking media due to spam checker") + # Note that we'll delete the stored media, due to the + # try/except below. The media also won't be stored in + # the DB. + raise SpamMediaException() + for provider in self.storage_providers: await provider.store_file(path, file_info) finished_called[0] = True yield f, fname, finish - except Exception: + except Exception as e: try: os.remove(fname) except Exception: pass - raise + + raise e from None if not finished_called: raise Exception("Finished callback not called") @@ -302,3 +318,38 @@ class FileResponder(Responder): def __exit__(self, exc_type, exc_val, exc_tb): self.open_file.close() + + +class SpamMediaException(NotFoundError): + """The media was blocked by a spam checker, so we simply 404 the request (in + the same way as if it was quarantined). + """ + + +@attr.s(slots=True) +class ReadableFileWrapper: + """Wrapper that allows reading a file in chunks, yielding to the reactor, + and writing to a callback. + + This is simplified `FileSender` that takes an IO object rather than an + `IConsumer`. + """ + + CHUNK_SIZE = 2 ** 14 + + clock = attr.ib(type=Clock) + path = attr.ib(type=str) + + async def write_chunks_to(self, callback: Callable[[bytes], None]): + """Reads the file in chunks and calls the callback with each chunk.""" + + with open(self.path, "rb") as file: + while True: + chunk = file.read(self.CHUNK_SIZE) + if not chunk: + break + + callback(chunk) + + # We yield to the reactor by sleeping for 0 seconds. + await self.clock.sleep(0) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
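ReadableFileWrapper above lets the spam checker stream a stored file in small chunks without hogging the reactor for the whole read. Below is a standalone asyncio analogue of its read loop, hedged: Synapse itself runs on the Twisted reactor and yields via its Clock, and the temporary file here merely stands in for stored media.

    import asyncio
    import hashlib
    import tempfile

    CHUNK_SIZE = 2 ** 14

    async def write_chunks_to(path, callback):
        with open(path, "rb") as f:
            while True:
                chunk = f.read(CHUNK_SIZE)
                if not chunk:
                    break
                callback(chunk)
                # Give other tasks a turn between chunks, mirroring
                # `await self.clock.sleep(0)` in the wrapper above.
                await asyncio.sleep(0)

    async def main():
        with tempfile.NamedTemporaryFile() as tmp:
            tmp.write(b"\x00" * 50_000)  # a few chunks' worth of fake media
            tmp.flush()
            digest = hashlib.sha256()
            await write_chunks_to(tmp.name, digest.update)
            print(digest.hexdigest())

    asyncio.run(main())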
index bf3be653aa..89dc6b1c98 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,7 +58,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) +_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I) +_xml_encoding_match = re.compile( + br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I +) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) OG_TAG_NAME_MAXLEN = 50 @@ -146,8 +149,7 @@ class PreviewUrlResource(DirectServeJsonResource): treq_args={"browser_like_redirects": True}, ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_blacklist=hs.config.url_preview_ip_range_blacklist, - http_proxy=os.getenvb(b"http_proxy"), - https_proxy=os.getenvb(b"HTTPS_PROXY"), + use_proxy=True, ) self.media_repo = media_repo self.primary_base_path = media_repo.primary_base_path @@ -300,24 +302,7 @@ class PreviewUrlResource(DirectServeJsonResource): with open(media_info["filename"], "rb") as file: body = file.read() - encoding = None - - # Let's try and figure out if it has an encoding set in a meta tag. - # Limit it to the first 1kb, since it ought to be in the meta tags - # at the top. - match = _charset_match.search(body[:1000]) - - # If we find a match, it should take precedence over the - # Content-Type header, so set it here. - if match: - encoding = match.group(1).decode("ascii") - - # If we don't find a match, we'll look at the HTTP Content-Type, and - # if that doesn't exist, we'll fall back to UTF-8. - if not encoding: - content_match = _content_type_match.match(media_info["media_type"]) - encoding = content_match.group(1) if content_match else "utf-8" - + encoding = get_html_media_encoding(body, media_info["media_type"]) og = decode_and_calc_og(body, media_info["uri"], encoding) # pre-cache the image for posterity @@ -594,8 +579,7 @@ class PreviewUrlResource(DirectServeJsonResource): ) async def _expire_url_cache_data(self) -> None: - """Clean up expired url cache content, media and thumbnails. - """ + """Clean up expired url cache content, media and thumbnails.""" # TODO: Delete from backup media store assert self._worker_run_media_background_jobs @@ -689,6 +673,48 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("No media removed from url cache") +def get_html_media_encoding(body: bytes, content_type: str) -> str: + """ + Get the encoding of the body based on the (presumably) HTML body or media_type. + + The precedence used for finding a character encoding is: + + 1. meta tag with a charset declared. + 2. The XML document's character encoding attribute. + 3. The Content-Type header. + 4. Fallback to UTF-8. + + Args: + body: The HTML document, as bytes. + content_type: The Content-Type header. + + Returns: + The character encoding of the body, as a string. + """ + # Limit searches to the first 1kb, since it ought to be at the top. + body_start = body[:1024] + + # Let's try and figure out if it has an encoding set in a meta tag. + match = _charset_match.search(body_start) + if match: + return match.group(1).decode("ascii") + + # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/> + + # If we didn't find a match, see if it an XML document with an encoding. + match = _xml_encoding_match.match(body_start) + if match: + return match.group(1).decode("ascii") + + # If we don't find a match, we'll look at the HTTP Content-Type, and + # if that doesn't exist, we'll fall back to UTF-8. 
+ content_match = _content_type_match.match(content_type) + if content_match: + return content_match.group(1) + + return "utf-8" + + def decode_and_calc_og( body: bytes, media_uri: str, request_encoding: Optional[str] = None ) -> Dict[str, Optional[str]]: @@ -725,6 +751,11 @@ def decode_and_calc_og( def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]: # Attempt to parse the body. If this fails, log and return no metadata. tree = etree.fromstring(body_attempt, parser) + + # The data was successfully parsed, but no tree was found. + if tree is None: + return {} + return _calc_og(tree, media_uri) # Attempt to parse the body. If this fails, log and return no metadata. diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
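The precedence implemented by get_html_media_encoding above can be exercised directly. This snippet reuses the same three regexes on made-up inputs:

    import re

    _charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
    _xml_encoding_match = re.compile(br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I)
    _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)

    def guess_encoding(body: bytes, content_type: str) -> str:
        body_start = body[:1024]
        match = _charset_match.search(body_start)
        if match:                       # 1. <meta charset=...>
            return match.group(1).decode("ascii")
        match = _xml_encoding_match.match(body_start)
        if match:                       # 2. <?xml ... encoding="..."?>
            return match.group(1).decode("ascii")
        content_match = _content_type_match.match(content_type)
        if content_match:               # 3. Content-Type header
            return content_match.group(1)
        return "utf-8"                  # 4. fallback

    print(guess_encoding(b'<html><meta charset="shift-jis">', "text/html"))  # shift-jis
    print(guess_encoding(b'<?xml version="1.0" encoding="UTF-16"?>', ""))    # UTF-16
    print(guess_encoding(b"<html>", 'text/html; charset="iso-8859-1"'))      # iso-8859-1
    print(guess_encoding(b"<html>", "text/html"))                            # utf-8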
index 6da76ae994..1136277794 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_string +from synapse.rest.media.v1.media_storage import SpamMediaException if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource): # disposition = headers.getRawHeaders(b"Content-Disposition")[0] # TODO(markjh): parse content-dispostion - content_uri = await self.media_repo.create_content( - media_type, upload_name, request.content, content_length, requester.user - ) + try: + content_uri = await self.media_repo.create_content( + media_type, upload_name, request.content, content_length, requester.user + ) + except SpamMediaException: + # For uploading of media we want to respond with a 400, instead of + # the default 404, as that would just be confusing. + raise SynapseError(400, "Bad content") logger.info("Uploaded content with URI %r", content_uri) diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index f7a0bc4bdb..1af33f0a45 100644 --- a/synapse/rest/synapse/client/oidc/callback_resource.py +++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -12,19 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging +from typing import TYPE_CHECKING from synapse.http.server import DirectServeHtmlResource +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) class OIDCCallbackResource(DirectServeHtmlResource): isLeaf = 1 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._oidc_handler = hs.get_oidc_handler() async def _async_render_GET(self, request): await self._oidc_handler.handle_oidc_callback(request) + + async def _async_render_POST(self, request): + # the auth response can be returned via an x-www-form-urlencoded form instead + # of GET params, as per + # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html. + await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py new file mode 100644
index 0000000000..e69de29bb2 --- /dev/null +++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py new file mode 100644
index 0000000000..6f2a1931c5 --- /dev/null +++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.config._base import ConfigError + +logger = logging.getLogger(__name__) + + +class DomainRuleChecker(object): + """ + A re-implementation of the SpamChecker that prevents users in one domain from + inviting users in other domains to rooms, based on a configuration. + + Takes a config in the format: + + spam_checker: + module: "rulecheck.DomainRuleChecker" + config: + domain_mapping: + "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ] + "other_inviter_domain": [ "invitee_domain_permitted" ] + default: False + + # Only let local users join rooms if they were explicitly invited. + can_only_join_rooms_with_invite: false + + # Only let local users create rooms if they are inviting only one + # other user, and that user matches the rules above. + can_only_create_one_to_one_rooms: false + + # Only let local users invite during room creation, regardless of the + # domain mapping rules above. + can_only_invite_during_room_creation: false + + # Prevent local users from inviting users from certain domains to + # rooms published in the room directory. + domains_prevented_from_being_invited_to_published_rooms: [] + + # Allow third party invites + can_invite_by_third_party_id: true + + Don't forget to consider if you can invite users from your own domain. + """ + + def __init__(self, config): + self.domain_mapping = config["domain_mapping"] or {} + self.default = config["default"] + + self.can_only_join_rooms_with_invite = config.get( + "can_only_join_rooms_with_invite", False + ) + self.can_only_create_one_to_one_rooms = config.get( + "can_only_create_one_to_one_rooms", False + ) + self.can_only_invite_during_room_creation = config.get( + "can_only_invite_during_room_creation", False + ) + self.can_invite_by_third_party_id = config.get( + "can_invite_by_third_party_id", True + ) + self.domains_prevented_from_being_invited_to_published_rooms = config.get( + "domains_prevented_from_being_invited_to_published_rooms", [] + ) + + def check_event_for_spam(self, event): + """Implements synapse.events.SpamChecker.check_event_for_spam + """ + return False + + def user_may_invite( + self, + inviter_userid, + invitee_userid, + third_party_invite, + room_id, + new_room, + published_room=False, + ): + """Implements synapse.events.SpamChecker.user_may_invite + """ + if self.can_only_invite_during_room_creation and not new_room: + return False + + if not self.can_invite_by_third_party_id and third_party_invite: + return False + + # This is a third party invite (without a bound mxid), so unless we have + # banned all third party invites (above) we allow it. 
+ if not invitee_userid: + return True + + inviter_domain = self._get_domain_from_id(inviter_userid) + invitee_domain = self._get_domain_from_id(invitee_userid) + + if inviter_domain not in self.domain_mapping: + return self.default + + if ( + published_room + and invitee_domain + in self.domains_prevented_from_being_invited_to_published_rooms + ): + return False + + return invitee_domain in self.domain_mapping[inviter_domain] + + def user_may_create_room( + self, userid, invite_list, third_party_invite_list, cloning + ): + """Implements synapse.events.SpamChecker.user_may_create_room + """ + + if cloning: + return True + + if not self.can_invite_by_third_party_id and third_party_invite_list: + return False + + number_of_invites = len(invite_list) + len(third_party_invite_list) + + if self.can_only_create_one_to_one_rooms and number_of_invites != 1: + return False + + return True + + def user_may_create_room_alias(self, userid, room_alias): + """Implements synapse.events.SpamChecker.user_may_create_room_alias + """ + return True + + def user_may_publish_room(self, userid, room_id): + """Implements synapse.events.SpamChecker.user_may_publish_room + """ + return True + + def user_may_join_room(self, userid, room_id, is_invited): + """Implements synapse.events.SpamChecker.user_may_join_room + """ + if self.can_only_join_rooms_with_invite and not is_invited: + return False + + return True + + @staticmethod + def parse_config(config): + """Implements synapse.events.SpamChecker.parse_config + """ + if "default" in config: + return config + else: + raise ConfigError("No default set for spam_config DomainRuleChecker") + + @staticmethod + def _get_domain_from_id(mxid): + """Parses a string and returns the domain part of the mxid. + + Args: + mxid (str): a valid mxid + + Returns: + str: the domain part of the mxid + + """ + idx = mxid.find(":") + if idx == -1: + raise Exception("Invalid ID: %r" % (mxid,)) + return mxid[idx + 1 :] diff --git a/synapse/server.py b/synapse/server.py
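To make the rules above concrete, here is how the checker behaves with a small, invented configuration (assuming DomainRuleChecker as defined above is importable):

    config = {
        "domain_mapping": {"example.com": ["friendly.org"]},
        "default": False,
    }
    checker = DomainRuleChecker(config)

    # example.com users may invite friendly.org users, and nobody else.
    assert checker.user_may_invite(
        "@a:example.com", "@b:friendly.org", None, "!room:x", new_room=False
    )
    assert not checker.user_may_invite(
        "@a:example.com", "@c:elsewhere.com", None, "!room:x", new_room=False
    )
    # Inviters from unmapped domains fall back to `default` (False here).
    assert not checker.user_may_invite(
        "@z:unmapped.net", "@b:friendly.org", None, "!room:x", new_room=False
    )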
index 9bdd3177d7..91d59b755a 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -24,8 +24,17 @@ import abc import functools import logging -import os -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + TypeVar, + Union, + cast, +) import twisted.internet.base import twisted.internet.tcp @@ -360,11 +369,7 @@ class HomeServer(metaclass=abc.ABCMeta): """ An HTTP client that uses configured HTTP(S) proxies. """ - return SimpleHttpClient( - self, - http_proxy=os.getenvb(b"http_proxy"), - https_proxy=os.getenvb(b"HTTPS_PROXY"), - ) + return SimpleHttpClient(self, use_proxy=True) @cache_in_self def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: @@ -376,8 +381,7 @@ class HomeServer(metaclass=abc.ABCMeta): self, ip_whitelist=self.config.ip_range_whitelist, ip_blacklist=self.config.ip_range_blacklist, - http_proxy=os.getenvb(b"http_proxy"), - https_proxy=os.getenvb(b"HTTPS_PROXY"), + use_proxy=True, ) @cache_in_self @@ -588,7 +592,9 @@ class HomeServer(metaclass=abc.ABCMeta): return UserDirectoryHandler(self) @cache_in_self - def get_groups_local_handler(self): + def get_groups_local_handler( + self, + ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]: if self.config.worker_app: return GroupsLocalWorkerHandler(self) else: diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
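The getters above rely on the @cache_in_self decorator to build each dependency once per HomeServer. A toy version of the pattern for illustration; the real decorator also guards against cyclic builds:

    import functools

    def cache_in_self(builder):
        # Build the dependency on first access, then reuse it.
        attr = "_" + builder.__name__

        @functools.wraps(builder)
        def _get(self):
            dep = getattr(self, attr, None)
            if dep is None:
                dep = builder(self)
                setattr(self, attr, dep)
            return dep

        return _get

    class FakeHomeServer:
        @cache_in_self
        def get_proxied_http_client(self):
            print("building client once")
            return object()

    hs = FakeHomeServer()
    assert hs.get_proxied_http_client() is hs.get_proxied_http_client()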
index 8dd01fce76..6652451346 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class ResourceLimitsServerNotices: - """ Keeps track of whether the server has reached it's resource limit and + """Keeps track of whether the server has reached its resource limit and ensures that the client is kept up to date. """ diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 3bd9ff8ca0..c3d6e80c49 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py
@@ -398,7 +398,7 @@ class StateHandler: async def resolve_state_groups_for_events( self, room_id: str, event_ids: Iterable[str] ) -> _StateCacheEntry: - """ Given a list of event_ids this method fetches the state at each + """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. Args: @@ -570,7 +570,9 @@ class StateResolutionHandler: return cache logger.info( - "Resolving state for %s with groups %s", room_id, list(group_names), + "Resolving state for %s with groups %s", + room_id, + list(group_names), ) state_groups_histogram.observe(len(state_groups_ids)) @@ -615,7 +617,7 @@ class StateResolutionHandler: event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be - used as a starting point fof finding the state we need; any missing + used as a starting point for finding the state we need; any missing events will be requested via state_map_factory. If None, all events will be fetched via state_res_store. @@ -656,11 +658,15 @@ class StateResolutionHandler: return self._report_biggest( - lambda i: i.cpu_time, "CPU time", _biggest_room_by_cpu_counter, + lambda i: i.cpu_time, + "CPU time", + _biggest_room_by_cpu_counter, ) self._report_biggest( - lambda i: i.db_time, "DB time", _biggest_room_by_db_counter, + lambda i: i.db_time, + "DB time", + _biggest_room_by_db_counter, ) self._state_res_metrics.clear() diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 85edae053d..ce255da6fd 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py
@@ -95,7 +95,11 @@ async def resolve_events_with_store( if event.room_id != room_id: raise Exception( "Attempting to state-resolve for room %s with event %s which is in %s" - % (room_id, event.event_id, event.room_id,) + % ( + room_id, + event.event_id, + event.room_id, + ) ) # get the ids of the auth events which allow us to authenticate the @@ -119,7 +123,11 @@ async def resolve_events_with_store( if event.room_id != room_id: raise Exception( "Attempting to state-resolve for room %s with event %s which is in %s" - % (room_id, event.event_id, event.room_id,) + % ( + room_id, + event.event_id, + event.room_id, + ) ) state_map.update(state_map_new) @@ -243,7 +251,7 @@ def _resolve_with_state( def _resolve_state_events( conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase] ) -> StateMap[EventBase]: - """ This is where we actually decide which of the conflicted state to + """This is where we actually decide which of the conflicted state to use. We resolve conflicts in the following order: diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e585954bd8..e73a548ee4 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py
@@ -118,7 +118,11 @@ async def resolve_events_with_store( if event.room_id != room_id: raise Exception( "Attempting to state-resolve for room %s with event %s which is in %s" - % (room_id, event.event_id, event.room_id,) + % ( + room_id, + event.event_id, + event.room_id, + ) ) full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map} diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c0d9d1240f..a3c52695e9 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py
@@ -43,8 +43,7 @@ __all__ = ["Databases", "DataStore"] class Storage: - """The high level interfaces for talking to various storage layers. - """ + """The high level interfaces for talking to various storage layers.""" def __init__(self, hs: "HomeServer", stores: Databases): # We include the main data store here mainly so that we don't have to diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 29b8ca676a..329660cf0f 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -77,7 +77,7 @@ class BackgroundUpdatePerformance: class BackgroundUpdater: - """ Background updates are updates to the database that run in the + """Background updates are updates to the database that run in the background. Each update processes a batch of data at once. We attempt to limit the impact of each update by monitoring how long each batch takes to process and autotuning the batch size. @@ -158,8 +158,7 @@ class BackgroundUpdater: return False async def has_completed_background_update(self, update_name: str) -> bool: - """Check if the given background update has finished running. - """ + """Check if the given background update has finished running.""" if self._all_done: return True @@ -198,7 +197,8 @@ class BackgroundUpdater: if not self._current_background_update: all_pending_updates = await self.db_pool.runInteraction( - "background_updates", get_background_updates_txn, + "background_updates", + get_background_updates_txn, ) if not all_pending_updates: # no work left to do diff --git a/synapse/storage/database.py b/synapse/storage/database.py
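The batch autotuning described in the BackgroundUpdater docstring above boils down to scaling the batch size by how far the last batch missed a target duration. A toy version with invented numbers, not the class's exact arithmetic:

    def next_batch_size(current: int, duration_ms: float,
                        target_ms: float = 100.0, minimum: int = 1) -> int:
        # Scale the batch so the next run takes roughly target_ms.
        if duration_ms <= 0:
            return current
        return max(minimum, int(current * target_ms / duration_ms))

    print(next_batch_size(100, 400.0))  # 25: last batch was 4x too slow, shrink
    print(next_batch_size(100, 50.0))   # 200: finished early, grow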
index d2ba4bd2fc..4646926449 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -85,8 +85,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { def make_pool( reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine ) -> adbapi.ConnectionPool: - """Get the connection pool for the database. - """ + """Get the connection pool for the database.""" # By default enable `cp_reconnect`. We need to fiddle with db_args in case # someone has explicitly set `cp_reconnect`. @@ -158,8 +157,8 @@ class LoggingDatabaseConnection: def commit(self) -> None: self.conn.commit() - def rollback(self, *args, **kwargs) -> None: - self.conn.rollback(*args, **kwargs) + def rollback(self) -> None: + self.conn.rollback() def __enter__(self) -> "Connection": self.conn.__enter__() @@ -244,12 +243,15 @@ class LoggingTransaction: assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) + def fetchone(self) -> Optional[Tuple]: + return self.txn.fetchone() + + def fetchmany(self, size: Optional[int] = None) -> List[Tuple]: + return self.txn.fetchmany(size=size) + def fetchall(self) -> List[Tuple]: return self.txn.fetchall() - def fetchone(self) -> Tuple: - return self.txn.fetchone() - def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() @@ -429,8 +431,7 @@ class DatabasePool: ) def is_running(self) -> bool: - """Is the database pool currently running - """ + """Is the database pool currently running""" return self._db_pool.running async def _check_safe_to_upsert(self) -> None: @@ -543,7 +544,11 @@ class DatabasePool: # This can happen if the database disappears mid # transaction. transaction_logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, + "[TXN OPERROR] {%s} %s %d/%d", + name, + e, + i, + N, ) if i < N: i += 1 @@ -564,7 +569,9 @@ class DatabasePool: conn.rollback() except self.engine.module.Error as e1: transaction_logger.warning( - "[TXN EROLL] {%s} %s", name, e1, + "[TXN EROLL] {%s} %s", + name, + e1, ) continue raise @@ -754,6 +761,7 @@ class DatabasePool: Returns: A list of dicts where the key is the column header. """ + assert cursor.description is not None, "cursor.description was None" col_headers = [intern(str(column[0])) for column in cursor.description] results = [dict(zip(col_headers, row)) for row in cursor] return results @@ -1402,7 +1410,10 @@ class DatabasePool: @staticmethod def simple_select_onecol_txn( - txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str, + txn: LoggingTransaction, + table: str, + keyvalues: Dict[str, Any], + retcol: str, ) -> List[Any]: sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} @@ -1712,7 +1723,11 @@ class DatabasePool: desc: description of the transaction, for logging and metrics """ await self.runInteraction( - desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True, + desc, + self.simple_delete_one_txn, + table, + keyvalues, + db_autocommit=True, ) @staticmethod diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
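The assert added to cursor_to_dict above exists because `cursor.description` is None until a query has produced a result set. The helper itself is easy to demonstrate against an in-memory sqlite3 database (table and rows invented):

    import sqlite3

    def cursor_to_dict(cursor):
        # description is None until a SELECT has run, hence the assert
        # added in the hunk above.
        assert cursor.description is not None, "cursor.description was None"
        col_headers = [str(column[0]) for column in cursor.description]
        return [dict(zip(col_headers, row)) for row in cursor]

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE users (name TEXT, admin INTEGER)")
    cur.execute("INSERT INTO users VALUES ('@alice:example.com', 1)")
    cur.execute("SELECT name, admin FROM users")
    print(cursor_to_dict(cur))  # [{'name': '@alice:example.com', 'admin': 1}]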
index 0c24325011..e84f8b42f7 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py
@@ -56,7 +56,10 @@ class Databases: database_config.databases, ) prepare_database( - db_conn, engine, hs.config, databases=database_config.databases, + db_conn, + engine, + hs.config, + databases=database_config.databases, ) database = DatabasePool(hs, database_config, engine) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5d0845588c..70b49854cf 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py
@@ -340,7 +340,7 @@ class DataStore( count = txn.fetchone()[0] sql = ( - "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url " + "SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, displayname, avatar_url " + sql_base + " ORDER BY u.name LIMIT ? OFFSET ?" ) diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index e550cbc866..03a38422a1 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -73,8 +73,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore): return self.services_cache def get_if_app_services_interested_in_user(self, user_id: str) -> bool: - """Check if the user is one associated with an app service (exclusively) - """ + """Check if the user is one associated with an app service (exclusively)""" if self.exclusive_user_regex: return bool(self.exclusive_user_regex.match(user_id)) else: diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
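get_if_app_services_interested_in_user above short-circuits on a single compiled regex. The sketch below shows roughly how such an exclusive-user regex is assembled, with example namespaces; the real store builds the union from the registered app services' exclusive namespaces:

    import re

    # Example exclusive namespaces from two registered app services.
    exclusive_patterns = [r"@irc_.*:example\.com", r"@slack_.*:example\.com"]
    exclusive_user_regex = re.compile(
        "|".join("(?:%s)" % (p,) for p in exclusive_patterns)
    )

    def interested_in_user(user_id: str) -> bool:
        return bool(exclusive_user_regex.match(user_id))

    print(interested_in_user("@irc_alice:example.com"))  # True
    print(interested_in_user("@alice:example.com"))      # False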
index ea1e8fb580..6d18e692b0 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py
@@ -280,8 +280,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore): return batch_size async def _devices_last_seen_update(self, progress, batch_size): - """Background update to insert last seen info into devices table - """ + """Background update to insert last seen info into devices table""" last_user_id = progress.get("last_user_id", "") last_device_id = progress.get("last_device_id", "") @@ -363,8 +362,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): @wrap_as_background_process("prune_old_user_ips") async def _prune_old_user_ips(self): - """Removes entries in user IPs older than the configured period. - """ + """Removes entries in user IPs older than the configured period.""" if self.user_ips_max_age is None: # Nothing to do @@ -565,7 +563,11 @@ class ClientIpStore(ClientIpWorkerStore): results = {} for key in self._batch_row_update: - uid, access_token, ip, = key + ( + uid, + access_token, + ip, + ) = key if uid == user_id: user_agent, _, last_seen = self._batch_row_update[key] results[(access_token, ip)] = (user_agent, last_seen) diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 31f70ac5ef..45ca6620a8 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py
@@ -450,7 +450,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): }, ) - # Add the messages to the approriate local device inboxes so that + # Add the messages to the appropriate local device inboxes so that # they'll be sent to the devices when they next sync. self._add_messages_to_local_device_inbox_txn( txn, stream_id, local_messages_by_user_then_device diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 659d8f245f..d327e9aa0b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -315,7 +315,8 @@ class DeviceWorkerStore(SQLBaseStore): # make sure we go through the devices in stream order device_ids = sorted( - user_devices.keys(), key=lambda i: query_map[(user_id, i)][0], + user_devices.keys(), + key=lambda i: query_map[(user_id, i)][0], ) for device_id in device_ids: @@ -366,8 +367,7 @@ class DeviceWorkerStore(SQLBaseStore): async def mark_as_sent_devices_by_remote( self, destination: str, stream_id: int ) -> None: - """Mark that updates have successfully been sent to the destination. - """ + """Mark that updates have successfully been sent to the destination.""" await self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, @@ -681,7 +681,8 @@ class DeviceWorkerStore(SQLBaseStore): return results async def get_user_ids_requiring_device_list_resync( - self, user_ids: Optional[Collection[str]] = None, + self, + user_ids: Optional[Collection[str]] = None, ) -> Set[str]: """Given a list of remote users return the list of users that we should resync the device lists for. If None is given instead of a list, @@ -721,8 +722,7 @@ class DeviceWorkerStore(SQLBaseStore): ) async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: - """Mark that we no longer track device lists for remote user. - """ + """Mark that we no longer track device lists for remote user.""" def _mark_remote_user_device_list_as_unsubscribed_txn(txn): self.db_pool.simple_delete_txn( @@ -902,7 +902,8 @@ class DeviceWorkerStore(SQLBaseStore): logger.info("Pruned %d device list outbound pokes", count) await self.db_pool.runInteraction( - "_prune_old_outbound_device_pokes", _prune_txn, + "_prune_old_outbound_device_pokes", + _prune_txn, ) @@ -943,7 +944,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): # clear out duplicate device list outbound pokes self.db_pool.updates.register_background_update_handler( - BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes, + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, + self._remove_duplicate_outbound_pokes, ) # a pair of background updates that were added during the 1.14 release cycle, @@ -1004,17 +1006,23 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): row = None for row in rows: self.db_pool.simple_delete_txn( - txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS}, + txn, + "device_lists_outbound_pokes", + {x: row[x] for x in KEY_COLS}, ) row["sent"] = False self.db_pool.simple_insert_txn( - txn, "device_lists_outbound_pokes", row, + txn, + "device_lists_outbound_pokes", + row, ) if row: self.db_pool.updates._background_update_progress_txn( - txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row}, + txn, + BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, + {"last_row": row}, ) return len(rows) @@ -1286,7 +1294,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): # we've done a full resync, so we remove the entry that says we need # to resync self.db_pool.simple_delete_txn( - txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, + txn, + table="device_lists_remote_resync", + keyvalues={"user_id": user_id}, ) async def add_device_change_to_streams( @@ -1336,7 +1346,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): stream_ids: List[str], ): txn.call_after( - self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1], + self._device_list_stream_cache.entity_has_changed, + user_id, + stream_ids[-1], ) min_stream_id = stream_ids[0] diff --git a/synapse/storage/databases/main/directory.py 
b/synapse/storage/databases/main/directory.py
index e5060d4c46..267b948397 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py
@@ -85,7 +85,7 @@ class DirectoryStore(DirectoryWorkerStore): servers: Iterable[str], creator: Optional[str] = None, ) -> None: - """ Creates an association between a room alias and room_id/servers + """Creates an association between a room alias and room_id/servers Args: room_alias: The alias to create. @@ -160,7 +160,10 @@ class DirectoryStore(DirectoryWorkerStore): return room_id async def update_aliases_for_room( - self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, + self, + old_room_id: str, + new_room_id: str, + creator: Optional[str] = None, ) -> None: """Repoint all of the aliases for a given room, to a different room. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 309f1e865b..f1e7859d26 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def count_e2e_one_time_keys( self, user_id: str, device_id: str ) -> Dict[str, int]: - """ Count the number of one time keys the server has for a device + """Count the number of one time keys the server has for a device Returns: A mapping from algorithm to number of keys for that algorithm. """ @@ -494,7 +494,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): ) def _get_bare_e2e_cross_signing_keys_bulk_txn( - self, txn: Connection, user_ids: List[str], + self, + txn: Connection, + user_ids: List[str], ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if @@ -556,7 +558,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): return result def _get_e2e_cross_signing_signatures_txn( - self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + self, + txn: Connection, + keys: Dict[str, Dict[str, dict]], + from_user_id: str, ) -> Dict[str, Dict[str, dict]]: """Returns the cross-signing signatures made by a user on a set of keys. diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 8326640d20..18ddb92fcc 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py
@@ -71,7 +71,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return await self.get_events_as_list(event_ids) async def get_auth_chain_ids( - self, event_ids: Collection[str], include_given: bool = False, + self, + event_ids: Collection[str], + include_given: bool = False, ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. @@ -273,7 +275,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # origin chain. if origin_sequence_number <= chains.get(origin_chain_id, 0): chains[target_chain_id] = max( - target_sequence_number, chains.get(target_chain_id, 0), + target_sequence_number, + chains.get(target_chain_id, 0), ) seen_chains.add(target_chain_id) @@ -371,7 +374,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas # and state sets {A} and {B} then walking the auth chains of A and B # would immediately show that C is reachable by both. However, if we # stopped at C then we'd only reach E via the auth chain of B and so E - # would errornously get included in the returned difference. + # would erroneously get included in the returned difference. # # The other thing that we do is limit the number of auth chains we walk # at once, due to practical limits (i.e. we can only query the database @@ -497,7 +500,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas a_ids = new_aids - # Mark that the auth event is reachable by the approriate sets. + # Mark that the auth event is reachable by the appropriate sets. sets.intersection_update(event_to_missing_sets[event_id]) search.sort() @@ -632,8 +635,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) async def get_min_depth(self, room_id: str) -> int: - """For the given room, get the minimum depth we have seen for it. - """ + """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( "get_min_depth", self._get_min_depth_interaction, room_id ) @@ -858,12 +860,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas ) await self.db_pool.runInteraction( - "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn, + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn, ) class EventFederationStore(EventFederationWorkerStore): - """ Responsible for storing and serving up the various graphs associated + """Responsible for storing and serving up the various graphs associated with an event. Including the main event graph and the auth chains for an event. diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 438383abe1..78245ad5bd 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight): def _deserialize_action(actions, is_highlight): - """Custom deserializer for actions. This allows us to "compress" common actions - """ + """Custom deserializer for actions. This allows us to "compress" common actions""" if actions: return db_to_json(actions) @@ -91,7 +90,10 @@ class EventPushActionsWorkerStore(SQLBaseStore): @cached(num_args=3, tree=True, max_entries=5000) async def get_unread_event_push_actions_by_room_for_user( - self, room_id: str, user_id: str, last_read_event_id: Optional[str], + self, + room_id: str, + user_id: str, + last_read_event_id: Optional[str], ) -> Dict[str, int]: """Get the notification count, the highlight count and the unread message count for a given user in a given room after the given read receipt. @@ -120,13 +122,19 @@ class EventPushActionsWorkerStore(SQLBaseStore): ) def _get_unread_counts_by_receipt_txn( - self, txn, room_id, user_id, last_read_event_id, + self, + txn, + room_id, + user_id, + last_read_event_id, ): stream_ordering = None if last_read_event_id is not None: stream_ordering = self.get_stream_id_for_event_txn( - txn, last_read_event_id, allow_none=True, + txn, + last_read_event_id, + allow_none=True, ) if stream_ordering is None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ccda9f1caa..287606cb4f 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -399,7 +399,9 @@ class PersistEventsStore: self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) def _persist_event_auth_chain_txn( - self, txn: LoggingTransaction, events: List[EventBase], + self, + txn: LoggingTransaction, + events: List[EventBase], ) -> None: # We only care about state events, so this if there are no state events. @@ -470,7 +472,11 @@ class PersistEventsStore: event_to_room_id = {e.event_id: e.room_id for e in state_events.values()} self._add_chain_cover_index( - txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + txn, + self.db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, ) @classmethod @@ -517,7 +523,10 @@ class PersistEventsStore: # simple_select_many, but this case happens rarely and almost always # with a single row.) auth_events = db_pool.simple_select_onecol_txn( - txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id", + txn, + "event_auth", + keyvalues={"event_id": event_id}, + retcol="auth_id", ) events_to_calc_chain_id_for.add(event_id) @@ -550,7 +559,9 @@ class PersistEventsStore: WHERE """ clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", missing_auth_chains, + txn.database_engine, + "event_id", + missing_auth_chains, ) txn.execute(sql + clause, args) @@ -704,7 +715,8 @@ class PersistEventsStore: if chain_map[a_id][0] != chain_id } for start_auth_id, end_auth_id in itertools.permutations( - event_to_auth_chain.get(event_id, []), r=2, + event_to_auth_chain.get(event_id, []), + r=2, ): if chain_links.exists_path_from( chain_map[start_auth_id], chain_map[end_auth_id] @@ -888,8 +900,7 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], ): - """Persist the mapping from transaction IDs to event IDs (if defined). - """ + """Persist the mapping from transaction IDs to event IDs (if defined).""" to_insert = [] for event, _ in events_and_contexts: @@ -909,7 +920,9 @@ class PersistEventsStore: if to_insert: self.db_pool.simple_insert_many_txn( - txn, table="event_txn_id", values=to_insert, + txn, + table="event_txn_id", + values=to_insert, ) def _update_current_state_txn( @@ -941,7 +954,9 @@ class PersistEventsStore: txn.execute(sql, (stream_id, self._instance_name, room_id)) self.db_pool.simple_delete_txn( - txn, table="current_state_events", keyvalues={"room_id": room_id}, + txn, + table="current_state_events", + keyvalues={"room_id": room_id}, ) else: # We're still in the room, so we update the current state as normal. @@ -1050,7 +1065,7 @@ class PersistEventsStore: # Figure out the changes of membership to invalidate the # `get_rooms_for_user` cache. # We find out which membership events we may have deleted - # and which we have added, then we invlidate the caches for all + # and which we have added, then we invalidate the caches for all # those users. members_changed = { state_key @@ -1608,8 +1623,7 @@ class PersistEventsStore: ) def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ + """Store a room member in the database.""" def str_or_none(val: Any) -> Optional[str]: return val if isinstance(val, str) else None @@ -2001,8 +2015,7 @@ class PersistEventsStore: @attr.s(slots=True) class _LinkMap: - """A helper type for tracking links between chains. - """ + """A helper type for tracking links between chains.""" # Stores the set of links as nested maps: source chain ID -> target chain ID # -> source sequence number -> target sequence number. 
@@ -2108,7 +2121,9 @@ class _LinkMap: yield (src_chain, src_seq, target_chain, target_seq) def exists_path_from( - self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int], + self, + src_tuple: Tuple[int, int], + target_tuple: Tuple[int, int], ) -> bool: """Checks if there is a path between the source chain ID/sequence and target chain ID/sequence. diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
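The nested-map comment above pins down what a "link" is; the following is a hedged reconstruction of what exists_path_from checks over such a map, not the class's actual code. The assumed semantics: a link recorded at (source chain, s) -> (target chain, t) proves a path for any query with source sequence >= s and target sequence <= t.

    # Nested map per the comment above:
    # source chain -> target chain -> source sequence -> target sequence.
    links = {1: {2: {3: 7}}}  # from (chain 1, seq 3) one can reach (chain 2, seq 7)

    def exists_path_from(src, target):
        (src_chain, src_seq), (target_chain, target_seq) = src, target
        if src_chain == target_chain:
            # Within one chain, later sequence numbers reach earlier ones.
            return target_seq <= src_seq
        for link_src_seq, link_target_seq in links.get(src_chain, {}).get(target_chain, {}).items():
            if link_src_seq <= src_seq and target_seq <= link_target_seq:
                return True
        return False

    print(exists_path_from((1, 5), (2, 6)))  # True: covered by the (3 -> 7) link
    print(exists_path_from((1, 2), (2, 6)))  # False: the link needs source seq >= 3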
index 5ca4fa6817..89274e75f7 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -32,8 +32,7 @@ logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True) class _CalculateChainCover: - """Return value for _calculate_chain_cover_txn. - """ + """Return value for _calculate_chain_cover_txn.""" # The last room_id/depth/stream processed. room_id = attr.ib(type=str) @@ -127,11 +126,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): ) self.db_pool.updates.register_background_update_handler( - "rejected_events_metadata", self._rejected_events_metadata, + "rejected_events_metadata", + self._rejected_events_metadata, ) self.db_pool.updates.register_background_update_handler( - "chain_cover", self._chain_cover_index, + "chain_cover", + self._chain_cover_index, ) async def _background_reindex_fields_sender(self, progress, batch_size): @@ -462,8 +463,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return num_handled async def _redactions_received_ts(self, progress, batch_size): - """Handles filling out the `received_ts` column in redactions. - """ + """Handles filling out the `received_ts` column in redactions.""" last_event_id = progress.get("last_event_id", "") def _redactions_received_ts_txn(txn): @@ -518,8 +518,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): return count async def _event_fix_redactions_bytes(self, progress, batch_size): - """Undoes hex encoded censored redacted event JSON. - """ + """Undoes hex encoded censored redacted event JSON.""" def _event_fix_redactions_bytes_txn(txn): # This update is quite fast due to new index. @@ -642,7 +641,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): LIMIT ? """ - txn.execute(sql, (last_event_id, batch_size,)) + txn.execute( + sql, + ( + last_event_id, + batch_size, + ), + ) return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore @@ -910,7 +915,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): # Annoyingly we need to gut wrench into the persit event store so that # we can reuse the function to calculate the chain cover for rooms. PersistEventsStore._add_chain_cover_index( - txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + txn, + self.db_pool, + event_to_room_id, + event_to_types, + event_to_auth_chain, ) return _CalculateChainCover( diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
index 0ac1da9c35..b3703ae161 100644 --- a/synapse/storage/databases/main/events_forward_extremities.py +++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -71,7 +71,9 @@ class EventForwardExtremitiesStore(SQLBaseStore): if txn.rowcount > 0: # Invalidate the cache self._invalidate_cache_and_stream( - txn, self.get_latest_event_ids_in_room, (room_id,), + txn, + self.get_latest_event_ids_in_room, + (room_id,), ) return txn.rowcount @@ -97,5 +99,6 @@ class EventForwardExtremitiesStore(SQLBaseStore): return self.db_pool.cursor_to_dict(txn) return await self.db_pool.runInteraction( - "get_forward_extremities_for_room", get_forward_extremities_for_room_txn, + "get_forward_extremities_for_room", + get_forward_extremities_for_room_txn, ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 71d823be72..7bab6aa009 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -120,7 +120,9 @@ class EventsWorkerStore(SQLBaseStore): # SQLite). if hs.get_instance_name() in hs.config.worker.writers.events: self._stream_id_gen = StreamIdGenerator( - db_conn, "events", "stream_ordering", + db_conn, + "events", + "stream_ordering", ) self._backfill_id_gen = StreamIdGenerator( db_conn, @@ -140,7 +142,8 @@ class EventsWorkerStore(SQLBaseStore): if hs.config.run_background_tasks: # We periodically clean out old transaction ID mappings self._clock.looping_call( - self._cleanup_old_transaction_ids, 5 * 60 * 1000, + self._cleanup_old_transaction_ids, + 5 * 60 * 1000, ) self._get_event_cache = LruCache( @@ -525,7 +528,7 @@ class EventsWorkerStore(SQLBaseStore): async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_types_to_include: List[EventTypes], + state_types_to_include: List[str], membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -1325,8 +1328,7 @@ class EventsWorkerStore(SQLBaseStore): return rows, to_token, True async def is_event_after(self, event_id1, event_id2): - """Returns True if event_id1 is after event_id2 in the stream - """ + """Returns True if event_id1 is after event_id2 in the stream""" to_1, so_1 = await self.get_event_ordering(event_id1) to_2, so_2 = await self.get_event_ordering(event_id2) return (to_1, so_1) > (to_2, so_2) @@ -1428,8 +1430,7 @@ class EventsWorkerStore(SQLBaseStore): @wrap_as_background_process("_cleanup_old_transaction_ids") async def _cleanup_old_transaction_ids(self): - """Cleans out transaction id mappings older than 24hrs. - """ + """Cleans out transaction id mappings older than 24hrs.""" def _cleanup_old_transaction_ids_txn(txn): sql = """ @@ -1440,5 +1441,6 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (one_day_ago,)) return await self.db_pool.runInteraction( - "_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn, + "_cleanup_old_transaction_ids", + _cleanup_old_transaction_ids_txn, ) diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 7218191965..ac07e0197b 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py
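The group_server.py hunks below swap a loose Dict[str, Union[str, bool]] annotation for a TypedDict, which lets mypy check both the key names and the per-key value types. A small sketch of the difference (the describe helper is illustrative only):

from typing_extensions import TypedDict

_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})

def describe(room: _RoomInGroup) -> str:
    # mypy now knows room["room_id"] is a str and room["is_public"] is a bool.
    visibility = "public" if room["is_public"] else "private"
    return room["room_id"] + " (" + visibility + ")"

describe({"room_id": "!abc:example.com", "is_public": True})   # type-checks
# describe({"room_id": "!abc:example.com"})                    # rejected: missing "is_public"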
@@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple + +from typing_extensions import TypedDict from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, db_to_json @@ -26,6 +28,9 @@ from synapse.util import json_encoder _DEFAULT_CATEGORY_ID = "" _DEFAULT_ROLE_ID = "" +# A room in a group. +_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool}) + class GroupServerWorkerStore(SQLBaseStore): async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]: @@ -72,7 +77,7 @@ class GroupServerWorkerStore(SQLBaseStore): async def get_rooms_in_group( self, group_id: str, include_private: bool = False - ) -> List[Dict[str, Union[str, bool]]]: + ) -> List[_RoomInGroup]: """Retrieve the rooms that belong to a given group. Does not return rooms that lack members. @@ -123,7 +128,9 @@ class GroupServerWorkerStore(SQLBaseStore): ) async def get_rooms_for_summary_by_category( - self, group_id: str, include_private: bool = False, + self, + group_id: str, + include_private: bool = False, ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: """Get the rooms and categories that should be included in a summary request @@ -368,8 +375,7 @@ class GroupServerWorkerStore(SQLBaseStore): async def is_user_invited_to_local_group( self, group_id: str, user_id: str ) -> Optional[bool]: - """Has the group server invited a user? - """ + """Has the group server invited a user?""" return await self.db_pool.simple_select_one_onecol( table="group_invites", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -427,8 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore): ) async def get_publicised_groups_for_user(self, user_id: str) -> List[str]: - """Get all groups a user is publicising - """ + """Get all groups a user is publicising""" return await self.db_pool.simple_select_onecol( table="local_group_membership", keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, @@ -437,8 +442,7 @@ class GroupServerWorkerStore(SQLBaseStore): ) async def get_attestations_need_renewals(self, valid_until_ms): - """Get all attestations that need to be renewed until givent time - """ + """Get all attestations that need to be renewed until the given time""" def _get_attestations_need_renewals_txn(txn): sql = """ @@ -781,8 +785,7 @@ class GroupServerStore(GroupServerWorkerStore): profile: Optional[JsonDict], is_public: Optional[bool], ) -> None: - """Add/update room category for group - """ + """Add/update room category for group""" insertion_values = {} update_values = {"category_id": category_id} # This cannot be empty @@ -818,8 +821,7 @@ class GroupServerStore(GroupServerWorkerStore): profile: Optional[JsonDict], is_public: Optional[bool], ) -> None: - """Add/remove user role - """ + """Add/remove user role""" insertion_values = {} update_values = {"role_id": role_id} # This cannot be empty @@ -1012,8 +1014,7 @@ class GroupServerStore(GroupServerWorkerStore): ) async def add_group_invite(self, group_id: str, user_id: str) -> None: - """Record that the group server has invited a user - """ + """Record that the group server has invited a user""" await self.db_pool.simple_insert( table="group_invites", values={"group_id": group_id, "user_id": user_id}, @@ -1156,8 +1157,7 @@ class GroupServerStore(GroupServerWorkerStore): async def update_group_publicity( self, group_id: str, user_id:
str, publicise: bool ) -> None: - """Update whether the user is publicising their membership of the group - """ + """Update whether the user is publicising their membership of the group""" await self.db_pool.simple_update_one( table="local_group_membership", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -1300,8 +1300,7 @@ class GroupServerStore(GroupServerWorkerStore): async def update_attestation_renewal( self, group_id: str, user_id: str, attestation: dict ) -> None: - """Update an attestation that we have renewed - """ + """Update an attestation that we have renewed""" await self.db_pool.simple_update_one( table="group_attestations_renewals", keyvalues={"group_id": group_id, "user_id": user_id}, @@ -1312,8 +1311,7 @@ class GroupServerStore(GroupServerWorkerStore): async def update_remote_attestion( self, group_id: str, user_id: str, attestation: dict ) -> None: - """Update an attestation that a remote has renewed - """ + """Update an attestation that a remote has renewed""" await self.db_pool.simple_update_one( table="group_attestations_remote", keyvalues={"group_id": group_id, "user_id": user_id}, diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 04ac2d0ced..d504323b03 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -33,8 +33,7 @@ db_binary_type = memoryview class KeyStore(SQLBaseStore): - """Persistence for signature verification keys - """ + """Persistence for signature verification keys""" @cached() def _get_server_verify_key(self, server_name_and_key_id): @@ -155,7 +154,7 @@ class KeyStore(SQLBaseStore): (server_name, key_id, from_server) triplet if one already existed. Args: server_name: The name of the server. - key_id: The identifer of the key this JSON is for. + key_id: The identifier of the key this JSON is for. from_server: The server this JSON was fetched from. ts_now_ms: The time now in milliseconds. ts_valid_until_ms: The time when this json stops being valid. @@ -182,7 +181,7 @@ class KeyStore(SQLBaseStore): async def get_server_keys_json( self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: - """Retrive the key json for a list of server_keys and key ids. + """Retrieve the key json for a list of server_keys and key ids. If no keys are found for a given server, key_id and source then that server, key_id, and source triplet entry will be an empty list. The JSON is returned as a byte array so that it can be efficiently diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index e017177655..a0313c3ccf 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py
@@ -169,7 +169,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def get_local_media_before( - self, before_ts: int, size_gt: int, keep_profiles: bool, + self, + before_ts: int, + size_gt: int, + keep_profiles: bool, ) -> List[str]: # to find files that have never been accessed (last_access_ts IS NULL) @@ -454,10 +457,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): ) async def get_remote_media_thumbnail( - self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str, + self, + origin: str, + media_id: str, + t_width: int, + t_height: int, + t_type: str, ) -> Optional[Dict[str, Any]]: - """Fetch the thumbnail info of given width, height and type. - """ + """Fetch the thumbnail info of given width, height and type.""" return await self.db_pool.simple_select_one( table="remote_media_cache_thumbnails", diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92e65aa640..614a418a15 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py
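The comment fixes below sit in counters that restrict themselves to local users with a LIKE clause: Matrix user IDs always end in ':server_name', so matching on that suffix is a cheap locality test. The same idea in plain Python, with invented IDs:

hostname = "example.com"
like_suffix = ":" + hostname  # SQL equivalent: sender LIKE '%:example.com'

senders = ["@alice:example.com", "@bob:matrix.org", "@carol:example.com"]
local = [s for s in senders if s.endswith(like_suffix)]
assert local == ["@alice:example.com", "@carol:example.com"]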
@@ -111,7 +111,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_sent_e2ee_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. + # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname sql = """ @@ -167,7 +167,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): async def count_daily_sent_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own - # hostname then thats your own fault. + # hostname then that's your own fault. like_clause = "%:" + self.hs.hostname sql = """ diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index dbbb99cb95..29edab34d4 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
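For context on the @cachedList reflow below: the decorator pairs the batch method get_presence_for_users with the per-item cached method _get_presence_for_user, so user IDs already cached are served from memory and only the misses hit the database. Roughly what that arranges, modelled with a plain dict (fetch_presence_from_db is a stand-in, not a real Synapse call):

cache = {}

async def get_presence_for_users(user_ids):
    hits = {u: cache[u] for u in user_ids if u in cache}
    misses = [u for u in user_ids if u not in cache]
    if misses:
        rows = await fetch_presence_from_db(misses)  # one batched SELECT
        cache.update(rows)
        hits.update(rows)
    return hits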
@@ -130,7 +130,9 @@ class PresenceStore(SQLBaseStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1, + cached_method_name="_get_presence_for_user", + list_name="user_ids", + num_args=1, ) async def get_presence_for_users(self, user_ids): rows = await self.db_pool.simple_select_many_batch( diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index 54ef0f1f54..2dcdf9beb0 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py
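The profile.py changes below add a push-based replication scheme: profile writes stamp rows with a batch number, assign_profile_batch sweeps up rows whose batch is still NULL, and the profile_replication_status table records how far each target host has been synced. A minimal driver loop under those assumptions (push_batch_to_host is a hypothetical sender, not part of this diff):

async def replicate_profiles_once(store):
    # Stamp any rows that have not yet been assigned to a batch.
    await store.assign_profile_batch()
    latest = await store.get_latest_profile_replication_batch_number()
    if latest is None:
        return  # nothing has ever been batched

    for host, last_synced in (await store.get_replication_hosts()).items():
        for batchnum in range(last_synced + 1, latest + 1):
            rows = await store.get_profile_batch(batchnum)
            await push_batch_to_host(host, rows)  # hypothetical HTTP push
            await store.update_replication_batch_for_host(host, batchnum)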
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,11 +13,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore from synapse.storage.databases.main.roommember import ProfileInfo +from synapse.types import UserID +from synapse.util.caches.descriptors import cached + +BATCH_SIZE = 100 class ProfileWorkerStore(SQLBaseStore): @@ -39,6 +44,7 @@ class ProfileWorkerStore(SQLBaseStore): avatar_url=profile["avatar_url"], display_name=profile["displayname"] ) + @cached(max_entries=5000) async def get_profile_displayname(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", @@ -47,6 +53,7 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_displayname", ) + @cached(max_entries=5000) async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]: return await self.db_pool.simple_select_one_onecol( table="profiles", @@ -55,6 +62,58 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) + async def get_latest_profile_replication_batch_number(self): + def f(txn): + txn.execute("SELECT MAX(batch) as maxbatch FROM profiles") + rows = self.db_pool.cursor_to_dict(txn) + return rows[0]["maxbatch"] + + return await self.db_pool.runInteraction( + "get_latest_profile_replication_batch_number", f + ) + + async def get_profile_batch(self, batchnum): + return await self.db_pool.simple_select_list( + table="profiles", + keyvalues={"batch": batchnum}, + retcols=("user_id", "displayname", "avatar_url", "active"), + desc="get_profile_batch", + ) + + async def assign_profile_batch(self): + def f(txn): + sql = ( + "UPDATE profiles SET batch = " + "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) " + "WHERE user_id in (" + " SELECT user_id FROM profiles WHERE batch is NULL limit ?" 
+ ")" + ) + txn.execute(sql, (BATCH_SIZE,)) + return txn.rowcount + + return await self.db_pool.runInteraction("assign_profile_batch", f) + + async def get_replication_hosts(self): + def f(txn): + txn.execute( + "SELECT host, last_synced_batch FROM profile_replication_status" + ) + rows = self.db_pool.cursor_to_dict(txn) + return {r["host"]: r["last_synced_batch"] for r in rows} + + return await self.db_pool.runInteraction("get_replication_hosts", f) + + async def update_replication_batch_for_host( + self, host: str, last_synced_batch: int + ): + return await self.db_pool.simple_upsert( + table="profile_replication_status", + keyvalues={"host": host}, + values={"last_synced_batch": last_synced_batch}, + desc="update_replication_batch_for_host", + ) + async def get_from_remote_profile_cache( self, user_id: str ) -> Optional[Dict[str, Any]]: @@ -72,32 +131,95 @@ class ProfileWorkerStore(SQLBaseStore): ) async def set_profile_displayname( - self, user_localpart: str, new_displayname: Optional[str] + self, user_localpart: str, new_displayname: Optional[str], batchnum: int ) -> None: - await self.db_pool.simple_update_one( + # Invalidate the read cache for this user + self.get_profile_displayname.invalidate((user_localpart,)) + + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, + values={"displayname": new_displayname, "batch": batchnum}, desc="set_profile_displayname", + lock=False, # we can do this because user_id has a unique index ) async def set_profile_avatar_url( - self, user_localpart: str, new_avatar_url: Optional[str] + self, user_localpart: str, new_avatar_url: Optional[str], batchnum: int ) -> None: - await self.db_pool.simple_update_one( + # Invalidate the read cache for this user + self.get_profile_avatar_url.invalidate((user_localpart,)) + + await self.db_pool.simple_upsert( table="profiles", keyvalues={"user_id": user_localpart}, - updatevalues={"avatar_url": new_avatar_url}, + values={"avatar_url": new_avatar_url, "batch": batchnum}, desc="set_profile_avatar_url", + lock=False, # we can do this because user_id has a unique index + ) + + async def set_profiles_active( + self, users: List[UserID], active: bool, hide: bool, batchnum: int, + ) -> None: + """Given a set of users, set active and hidden flags on them. + + Args: + users: A list of UserIDs + active: Whether to set the users to active or inactive + hide: Whether to hide the users (withold from replication). 
If + False and active is False, users will have their profiles + erased + batchnum: The batch number, used for profile replication + """ + # Convert list of localparts to list of tuples containing localparts + user_localparts = [(user.localpart,) for user in users] + + # Generate list of value tuples for each user + value_names = ("active", "batch") + values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple] + + if not active and not hide: + # we are deactivating for real (not in hide mode) + # so clear the profile information + value_names += ("avatar_url", "displayname") + values = [v + (None, None) for v in values] + + return await self.db_pool.runInteraction( + "set_profiles_active", + self.db_pool.simple_upsert_many_txn, + table="profiles", + key_names=("user_id",), + key_values=user_localparts, + value_names=value_names, + value_values=values, + ) + + async def add_remote_profile_cache( + self, user_id: str, displayname: str, avatar_url: str + ) -> None: + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + await self.db_pool.simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", ) async def update_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> int: - return await self.db_pool.simple_update( + return await self.db_pool.simple_upsert( table="remote_profile_cache", keyvalues={"user_id": user_id}, - updatevalues={ + values={ "displayname": displayname, "avatar_url": avatar_url, "last_check": self._clock.time_msec(), @@ -118,8 +240,7 @@ class ProfileWorkerStore(SQLBaseStore): ) async def is_subscribed_remote_profile_for_user(self, user_id): - """Check whether we are interested in a remote user's profile. - """ + """Check whether we are interested in a remote user's profile.""" res = await self.db_pool.simple_select_one_onecol( table="group_users", keyvalues={"user_id": user_id}, @@ -145,8 +266,7 @@ class ProfileWorkerStore(SQLBaseStore): async def get_remote_profile_cache_entries_that_expire( self, last_checked: int ) -> List[Dict[str, str]]: - """Get all users who haven't been checked since `last_checked` - """ + """Get all users who haven't been checked since `last_checked`""" def _get_remote_profile_cache_entries_that_expire_txn(txn): sql = """ @@ -166,6 +286,17 @@ class ProfileWorkerStore(SQLBaseStore): class ProfileStore(ProfileWorkerStore): + def __init__(self, database, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "profile_replication_status_host_index", + index_name="profile_replication_status_idx", + table="profile_replication_status", + columns=["host"], + unique=True, + ) + async def add_remote_profile_cache( self, user_id: str, displayname: str, avatar_url: str ) -> None: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 711d5aa23d..9e58dc0e6a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -168,7 +168,9 @@ class PushRulesWorkerStore( ) @cachedList( - cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, + cached_method_name="get_push_rules_for_user", + list_name="user_ids", + num_args=1, ) async def bulk_get_push_rules(self, user_ids): if not user_ids: @@ -195,7 +197,9 @@ class PushRulesWorkerStore( use_new_defaults = user_id in self._users_new_default_push_rules results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), use_new_defaults, + rules, + enabled_map_by_user.get(user_id, {}), + use_new_defaults, ) return results diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 2687ef3e43..7cb69dd6bd 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -179,7 +179,9 @@ class PusherWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList( - cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, + cached_method_name="get_if_user_has_pusher", + list_name="user_ids", + num_args=1, ) async def get_if_users_have_pushers( self, user_ids: Iterable[str] @@ -263,7 +265,8 @@ class PusherWorkerStore(SQLBaseStore): params_by_room = {} for row in res: params_by_room[row["room_id"]] = ThrottleParams( - row["last_sent_ts"], row["throttle_ms"], + row["last_sent_ts"], + row["throttle_ms"], ) return params_by_room diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e4843a202c..43c852c96c 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -160,7 +160,7 @@ class ReceiptsWorkerStore(SQLBaseStore): Args: room_id: List of room_ids. - to_key: Max stream id to fetch receipts upto. + to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. @@ -189,7 +189,7 @@ class ReceiptsWorkerStore(SQLBaseStore): Args: room_ids: The room id. - to_key: Max stream id to fetch receipts upto. + to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. @@ -208,8 +208,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None ) -> List[dict]: - """See get_linearized_receipts_for_room - """ + """See get_linearized_receipts_for_room""" def f(txn): if from_key: @@ -304,7 +303,9 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results - @cached(num_args=2,) + @cached( + num_args=2, + ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None ) -> Dict[str, JsonDict]: @@ -312,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore): to a limit of the latest 100 read receipts. Args: - to_key: Max stream id to fetch receipts upto. + to_key: Max stream id to fetch receipts up to. from_key: Min stream id to fetch receipts from. None fetches from the start. diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8405dd460f..25d8dcb6ab 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py
@@ -79,15 +79,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): # call `find_max_generated_user_id_localpart` each time, which is # expensive if there are many entries. self._user_id_seq = build_sequence_generator( - database.engine, find_max_generated_user_id_localpart, "user_id_seq", + database.engine, + find_max_generated_user_id_localpart, + "user_id_seq", ) - self._account_validity = hs.config.account_validity - if hs.config.run_background_tasks and self._account_validity.enabled: - self._clock.call_later( - 0.0, self._set_expiration_date_when_missing, + self._account_validity_enabled = hs.config.account_validity_enabled + if self._account_validity_enabled: + self._account_validity_period = hs.config.account_validity_period + self._account_validity_startup_job_max_delta = ( + hs.config.account_validity_startup_job_max_delta ) + if hs.config.run_background_tasks: + self._clock.call_later( + 0.0, + self._set_expiration_date_when_missing, + ) + # Create a background job for culling expired 3PID validity tokens if hs.config.run_background_tasks: self._clock.looping_call( @@ -110,6 +119,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "creation_ts", "user_type", "deactivated", + "shadow_banned", ], allow_none=True, desc="get_user_by_id", @@ -183,6 +193,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): expiration_ts: int, email_sent: bool, renewal_token: Optional[str] = None, + token_used_ts: Optional[int] = None, ) -> None: """Updates the account validity properties of the given account, with the given values. @@ -196,6 +207,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): period. renewal_token: Renewal token the user can use to extend the validity of their account. Defaults to no token. + token_used_ts: A timestamp of when the current token was used to renew + the account. """ def set_account_validity_for_user_txn(txn): @@ -207,6 +220,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "expiration_ts_ms": expiration_ts, "email_sent": email_sent, "renewal_token": renewal_token, + "token_used_ts_ms": token_used_ts, }, ) self._invalidate_cache_and_stream( @@ -217,10 +231,41 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "set_account_validity_for_user", set_account_validity_for_user_txn ) + async def get_expired_users(self): + """Get UserIDs of all expired users. + + Users who are not active, or do not have profile information, are + excluded from the results. + + Returns: + Deferred[List[UserID]]: List of expired user IDs + """ + + def get_expired_users_txn(txn, now_ms): + # We need to use pattern matching as profiles.user_id is confusingly just the + # user's localpart, whereas account_validity.user_id is a full user ID + sql = """ + SELECT av.user_id from account_validity AS av + LEFT JOIN profiles as p + ON av.user_id LIKE '%%' || p.user_id || ':%%' + WHERE expiration_ts_ms <= ? + AND p.active = 1 + """ + txn.execute(sql, (now_ms,)) + rows = txn.fetchall() + + return [UserID.from_string(row[0]) for row in rows] + + res = await self.db_pool.runInteraction( + "get_expired_users", get_expired_users_txn, self._clock.time_msec() + ) + + return res + async def set_renewal_token_for_user( self, user_id: str, renewal_token: str ) -> None: - """Defines a renewal token for a given user. + """Defines a renewal token for a given user, and clears the token_used timestamp. Args: user_id: ID of the user to set the renewal token for. 
@@ -233,26 +278,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.simple_update_one( table="account_validity", keyvalues={"user_id": user_id}, - updatevalues={"renewal_token": renewal_token}, + updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None}, desc="set_renewal_token_for_user", ) - async def get_user_from_renewal_token(self, renewal_token: str) -> str: - """Get a user ID from a renewal token. + async def get_user_from_renewal_token( + self, renewal_token: str + ) -> Tuple[str, int, Optional[int]]: + """Get a user ID and renewal status from a renewal token. Args: renewal_token: The renewal token to perform the lookup with. Returns: - The ID of the user to which the token belongs. + A tuple containing the following values: + * The ID of a user to which the token belongs. + * An int representing the user's expiry timestamp as milliseconds since the + epoch, or 0 if the token was invalid. + * An optional int representing the timestamp of when the user last renewed their + account, as milliseconds since the epoch. None if the account + has not been renewed using the current token yet. """ - return await self.db_pool.simple_select_one_onecol( + ret_dict = await self.db_pool.simple_select_one( table="account_validity", keyvalues={"renewal_token": renewal_token}, - retcol="user_id", + retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"], desc="get_user_from_renewal_token", ) + return ( + ret_dict["user_id"], + ret_dict["expiration_ts_ms"], + ret_dict["token_used_ts_ms"], + ) + async def get_renewal_token_for_user(self, user_id: str) -> str: """Get the renewal token associated with a given user ID. @@ -291,7 +350,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "get_users_expiring_soon", select_users_txn, self._clock.time_msec(), - self.config.account_validity.renew_at, + self.config.account_validity_renew_at, ) async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: @@ -323,6 +382,54 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="delete_account_validity_for_user", ) + async def get_info_for_users( + self, user_ids: List[str], + ): + """Return the user info for a given set of users. + + Args: + user_ids: A list of users to return information about + + Returns: + Deferred[Dict[str, Dict[str, bool]]]: A dictionary mapping each user ID to + a dict with the following keys: + * expired - whether this is an expired user + * deactivated - whether this is a deactivated user + """ + # Get information of all our local users + def _get_info_for_users_txn(txn): + rows = [] + + for user_id in user_ids: + sql = """ + SELECT u.name, u.deactivated, av.expiration_ts_ms + FROM users as u + LEFT JOIN account_validity as av + ON av.user_id = u.name + WHERE u.name = ? + """ + + txn.execute(sql, (user_id,)) + row = txn.fetchone() + if row: + rows.append(row) + + return rows + + info_rows = await self.db_pool.runInteraction( + "get_info_for_users", _get_info_for_users_txn + ) + + return { + user_id: { + "expired": ( + expiration is not None and self._clock.time_msec() >= expiration + ), + "deactivated": deactivated == 1, + } + for user_id, deactivated, expiration in info_rows + } + async def is_server_admin(self, user: UserID) -> bool: """Determines if a user is an admin of this homeserver.
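The widened return value of get_user_from_renewal_token above exists so callers can tell a fresh renewal from a replayed token: a non-None token_used_ts means the token has already been spent. A hedged sketch of a consumer, built only from the store methods shown in this diff (the handler shape itself is illustrative):

async def renew_account_sketch(store, renewal_token, validity_period_ms, now_ms):
    user_id, expiration_ts, token_used_ts = await store.get_user_from_renewal_token(
        renewal_token
    )
    if token_used_ts is not None:
        # Token already used: report the expiry it produced, don't extend again.
        return user_id, expiration_ts, False

    new_expiration = now_ms + validity_period_ms
    await store.set_account_validity_for_user(
        user_id=user_id,
        expiration_ts=new_expiration,
        email_sent=False,
        renewal_token=renewal_token,
        token_used_ts=now_ms,
    )
    return user_id, new_expiration, True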
@@ -369,23 +476,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): """ def set_shadow_banned_txn(txn): + user_id = user.to_string() self.db_pool.simple_update_one_txn( txn, table="users", - keyvalues={"name": user.to_string()}, + keyvalues={"name": user_id}, updatevalues={"shadow_banned": shadow_banned}, ) # In order for this to apply immediately, clear the cache for this user. tokens = self.db_pool.simple_select_onecol_txn( txn, table="access_tokens", - keyvalues={"user_id": user.to_string()}, + keyvalues={"user_id": user_id}, retcol="token", ) for token in tokens: self._invalidate_cache_and_stream( txn, self.get_user_by_access_token, (token,) ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) @@ -951,11 +1060,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): delta equal to 10% of the validity period. """ now_ms = self._clock.time_msec() - expiration_ts = now_ms + self._account_validity.period + expiration_ts = now_ms + self._account_validity_period if use_delta: expiration_ts = self.rand.randrange( - expiration_ts - self._account_validity.startup_job_max_delta, + expiration_ts - self._account_validity_startup_job_max_delta, expiration_ts, ) @@ -1393,7 +1502,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): except self.database_engine.module.IntegrityError: raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE) - if self._account_validity.enabled: + if self._account_validity_enabled: self.set_expiration_date_for_user_txn(txn, user_id) if create_profile_with_displayname: diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
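On the expiration arithmetic in the hunk above: the randomised delta pulls each backfilled expiry back by up to 10% of the validity period, so pre-existing accounts don't all expire in the same tick. A worked example with assumed numbers (30-day period):

import random

now_ms = 1_600_000_000_000
period_ms = 30 * 24 * 60 * 60 * 1000   # 30-day validity period
max_delta_ms = period_ms // 10          # 10% of the period

expiration_ts = now_ms + period_ms
# With use_delta=True the expiry lands uniformly in the last 10% of the window.
expiration_ts = random.randrange(expiration_ts - max_delta_ms, expiration_ts)

assert now_ms + period_ms - max_delta_ms <= expiration_ts < now_ms + period_ms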
index a9fcb5f59c..8db6f1396a 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py
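Among the room.py hunks below, get_retention_policy_for_room gains an early return when the retention feature is disabled: a policy of {"min_lifetime": None, "max_lifetime": None} acts as the "keep everything" sentinel for downstream filtering. A compact model of how such a policy might be consumed (the filter function is illustrative, not code from this diff):

def event_passes_retention(event_age_ms, policy):
    # A None max_lifetime means events never age out, so the
    # disabled-feature policy lets every event through.
    max_lifetime = policy["max_lifetime"]
    return max_lifetime is None or event_age_ms <= max_lifetime

policy = {"min_lifetime": None, "max_lifetime": None}  # retention disabled
assert event_passes_retention(10**12, policy)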
@@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import collections import logging from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, JoinRules from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.storage._base import SQLBaseStore, db_to_json @@ -178,11 +177,13 @@ class RoomWorkerStore(SQLBaseStore): INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR history_visibility = 'world_readable' + join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + OR history_visibility = 'world_readable' ) AND joined_members > 0 """ % { - "published_sql": published_sql + "published_sql": published_sql, + "knock_join_rule": JoinRules.KNOCK, } txn.execute(sql, query_args) @@ -193,8 +194,7 @@ class RoomWorkerStore(SQLBaseStore): ) async def get_room_count(self) -> int: - """Retrieve the total number of rooms. - """ + """Retrieve the total number of rooms.""" def f(txn): sql = "SELECT count(*) FROM rooms" @@ -305,7 +305,7 @@ class RoomWorkerStore(SQLBaseStore): sql = """ SELECT room_id, name, topic, canonical_alias, joined_members, - avatar, history_visibility, joined_members, guest_access + avatar, history_visibility, guest_access, join_rules FROM ( %(published_sql)s ) published @@ -313,7 +313,8 @@ INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR history_visibility = 'world_readable' + join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + OR history_visibility = 'world_readable' ) AND joined_members > 0 %(where_clause)s @@ -322,6 +323,7 @@ "published_sql": published_sql, "where_clause": where_clause, "dir": "DESC" if forwards else "ASC", + "knock_join_rule": JoinRules.KNOCK, } if limit is not None: @@ -356,6 +358,23 @@ class RoomWorkerStore(SQLBaseStore): desc="is_room_blocked", ) + async def is_room_published(self, room_id: str) -> bool: + """Check whether a room has been published in the local public room + directory. + + Args: + room_id + Returns: + Whether the room is currently published in the room directory + """ + # Get room information + room_info = await self.get_room(room_id) + if not room_info: + return False + + # Check the is_public value + return room_info.get("is_public", False) + async def get_rooms_paginate( self, start: int, @@ -517,7 +536,8 @@ class RoomWorkerStore(SQLBaseStore): return rooms, room_count[0] return await self.db_pool.runInteraction( - "get_rooms_paginate", _get_rooms_paginate_txn, + "get_rooms_paginate", + _get_rooms_paginate_txn, ) @cached(max_entries=10000) @@ -564,6 +584,11 @@ class RoomWorkerStore(SQLBaseStore): Returns: dict[int, int]: "min_lifetime" and "max_lifetime" for this room. """ + # If the room retention feature is disabled, return a policy with no minimum or + # maximum, so that events are not mistakenly filtered out when sending them to + # the client.
+ if not self.config.retention_enabled: + return {"min_lifetime": None, "max_lifetime": None} def get_retention_policy_for_room_txn(txn): txn.execute( @@ -578,7 +603,8 @@ class RoomWorkerStore(SQLBaseStore): return self.db_pool.cursor_to_dict(txn) ret = await self.db_pool.runInteraction( - "get_retention_policy_for_room", get_retention_policy_for_room_txn, + "get_retention_policy_for_room", + get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default @@ -707,7 +733,10 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs async def quarantine_media_by_id( - self, server_name: str, media_id: str, quarantined_by: str, + self, + server_name: str, + media_id: str, + quarantined_by: str, ) -> int: """quarantines a single local or remote media id @@ -961,7 +990,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self.config = hs.config self.db_pool.updates.register_background_update_handler( - "insert_room_retention", self._background_insert_retention, + "insert_room_retention", + self._background_insert_retention, ) self.db_pool.updates.register_background_update_handler( @@ -1033,7 +1063,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore): return False end = await self.db_pool.runInteraction( - "insert_room_retention", _background_insert_retention_txn, + "insert_room_retention", + _background_insert_retention_txn, ) if end: @@ -1044,7 +1075,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): async def _background_add_rooms_room_version_column( self, progress: dict, batch_size: int ): - """Background update to go and add room version inforamtion to `rooms` + """Background update to go and add room version information to `rooms` table from `current_state_events` table. """ @@ -1588,7 +1619,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): LIMIT ? OFFSET ? """.format( - where_clause=where_clause, order=order, + where_clause=where_clause, + order=order, ) args += [limit, start] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 92382bed28..a9216ca9ae 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -70,10 +70,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): ): self._known_servers_count = 1 self.hs.get_clock().looping_call( - self._count_known_servers, 60 * 1000, + self._count_known_servers, + 60 * 1000, ) self.hs.get_clock().call_later( - 1000, self._count_known_servers, + 1000, + self._count_known_servers, ) LaterGauge( "synapse_federation_known_servers", @@ -174,7 +176,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): @cached(max_entries=100000) async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]: - """ Get the details of a room roughly suitable for use by the room + """Get the details of a room roughly suitable for use by the room summary extension to /sync. Useful when lazy loading room members. Args: room_id: The room ID to query @@ -488,8 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): async def get_users_who_share_room_with_user( self, user_id: str, cache_context: _CacheContext ) -> Set[str]: - """Returns the set of users who share a room with `user_id` - """ + """Returns the set of users who share a room with `user_id`""" room_ids = await self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate ) @@ -618,7 +619,8 @@ class RoomMemberWorkerStore(EventsWorkerStore): raise NotImplementedError() @cachedList( - cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids", + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", ) async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]): """For given set of member event_ids check if they point to a join @@ -802,8 +804,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] ) -> List[dict]: - """Get user_id and membership of a set of event IDs. - """ + """Get user_id and membership of a set of event IDs.""" return await self.db_pool.simple_select_many_batch( table="room_memberships", diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index ad875c733a..3907189e29 100644 --- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py +++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs): def run_upgrade(cur, database_engine, *args, **kwargs): cur.execute( - "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),), + "UPDATE remote_media_cache SET last_access_ts = ?", + (int(time.time() * 1000),), ) diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql new file mode 100644
index 0000000000..e744c02fe8 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Add a batch number to track changes to profiles and the + * order they're made in so we can replicate user profiles + * to other hosts as they change + */ +ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL; + +/* + * Index on the batch number so we can get profiles + * by their batch + */ +CREATE INDEX profiles_batch_idx ON profiles(batch); + +/* + * A table to track what batch of user profiles has been + * synced to what profile replication target. + */ +CREATE TABLE profile_replication_status ( + host TEXT NOT NULL, + last_synced_batch BIGINT NOT NULL +); diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql new file mode 100644
index 0000000000..96051ac179 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * A flag saying whether the user owning the profile has been deactivated. + * This really belongs on the users table, not here, but the users table + * stores users by their full user_id and profiles stores them by localpart, + * so we can't easily join between the two tables. Plus, the batch number + * really ought to represent data in this table that has changed. + */ +ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL; \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql new file mode 100644
index 0000000000..7542ab8cbd --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host); \ No newline at end of file diff --git a/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql new file mode 100644
index 0000000000..4836dac16e --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/19account_validity_token_used_ts_ms.sql
@@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Track when users renew their account using the value of the 'renewal_token' column. +-- This field should be set to NULL after a fresh token is generated. +ALTER TABLE account_validity ADD token_used_ts_ms BIGINT; diff --git a/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql b/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql new file mode 100644
index 0000000000..658f55a384 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/24add_knock_members_to_stats.sql
@@ -0,0 +1,17 @@ +/* Copyright 2020 Sorunome + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_stats_current ADD knocked_members INT NOT NULL DEFAULT '0'; +ALTER TABLE room_stats_historical ADD knocked_members BIGINT NOT NULL DEFAULT '0'; diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644 --- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream ( +CREATE TABLE profile_replication_status ( + host text NOT NULL, + last_synced_batch bigint NOT NULL +); + + + CREATE TABLE profiles ( user_id text NOT NULL, displayname text, - avatar_url text + avatar_url text, + batch bigint, + active smallint DEFAULT 1 NOT NULL ); @@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id); +CREATE INDEX profiles_batch_idx ON profiles USING btree (batch); + + + CREATE INDEX public_room_index ON rooms USING btree (is_public); diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..301c566a70 100644 --- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) ); CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) ); CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL ); -CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) ); +CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) ); CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) ); CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER ); CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) ); @@ -67,11 +67,6 @@ CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT CREATE INDEX user_threepids_user_id ON user_threepids(user_id); CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value ) /* event_search(event_id,room_id,sender,"key",value) */; -CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value'); -CREATE TABLE IF NOT EXISTS 'event_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'event_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'event_search_stat'(id INTEGER PRIMARY KEY, value BLOB); CREATE TABLE guest_access( event_id TEXT NOT NULL, room_id TEXT NOT NULL, guest_access TEXT NOT NULL, UNIQUE (event_id) ); CREATE TABLE history_visibility( event_id TEXT NOT NULL, room_id TEXT NOT NULL, history_visibility TEXT NOT NULL, UNIQUE (event_id) ); CREATE TABLE room_tags( user_id TEXT NOT NULL, room_id TEXT NOT NULL, tag TEXT NOT NULL, content TEXT NOT NULL, CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag) ); @@ -149,11 +144,6 @@ CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_las CREATE TABLE user_directory_stream_pos ( Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, stream_id BIGINT, CHECK (Lock='X') ); CREATE VIRTUAL TABLE user_directory_search USING fts4 ( user_id, value ) /* user_directory_search(user_id,value) */; -CREATE TABLE IF NOT EXISTS 'user_directory_search_content'(docid INTEGER PRIMARY KEY, 'c0user_id', 'c1value'); -CREATE TABLE IF NOT EXISTS 'user_directory_search_segments'(blockid INTEGER PRIMARY KEY, block BLOB); -CREATE TABLE IF NOT EXISTS 
'user_directory_search_segdir'(level INTEGER,idx INTEGER,start_block INTEGER,leaves_end_block INTEGER,end_block INTEGER,root BLOB,PRIMARY KEY(level, idx)); -CREATE TABLE IF NOT EXISTS 'user_directory_search_docsize'(docid INTEGER PRIMARY KEY, size BLOB); -CREATE TABLE IF NOT EXISTS 'user_directory_search_stat'(id INTEGER PRIMARY KEY, value BLOB); CREATE TABLE blocked_rooms ( room_id TEXT NOT NULL, user_id TEXT NOT NULL ); CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); CREATE TABLE IF NOT EXISTS "local_media_repository_url_cache"( url TEXT, response_code INTEGER, etag TEXT, expires_ts BIGINT, og TEXT, media_id TEXT, download_ts BIGINT ); @@ -202,6 +192,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id); CREATE INDEX group_invites_u_idx ON group_invites(user_id); CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +CREATE INDEX profiles_batch_idx ON profiles(batch); +CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL ); CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL ); CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp); CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp); diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 3c1e33819b..a7f371732f 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -52,8 +52,7 @@ class _GetStateGroupDelta( # this inherits from EventsWorkerStore because it calls self.get_events class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): - """The parts of StateGroupStore that can be called from workers. - """ + """The parts of StateGroupStore that can be called from workers.""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -276,8 +275,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): num_args=1, ) async def _get_state_group_for_events(self, event_ids): - """Returns mapping event_id -> state_group - """ + """Returns mapping event_id -> state_group""" rows = await self.db_pool.simple_select_many_batch( table="event_to_state_groups", column="event_id", @@ -338,7 +336,8 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): columns=["state_group"], ) self.db_pool.updates.register_background_update_handler( - self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms, + self.DELETE_CURRENT_STATE_UPDATE_NAME, + self._background_remove_left_rooms, ) async def _background_remove_left_rooms(self, progress, batch_size): @@ -487,7 +486,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): - """ Keeps track of the state at a given event. + """Keeps track of the state at a given event. This is done by the concept of `state groups`. Every event is assigned a state group (identified by an arbitrary string), which references a diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 356623fc6e..0dbb501f16 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py
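The comment fixes below sit in get_current_state_deltas, whose clipping logic deserves a gloss: rows sharing a stream_id must be returned together, so instead of a bare LIMIT the transaction walks per-stream_id row counts and stops before the running total passes 100. A small model of that walk with invented counts:

# (stream_id, row_count) pairs, as produced by the GROUP BY query.
counts = [(101, 40), (102, 35), (103, 50), (104, 10)]

total = 0
max_stream_id = None
for stream_id, count in counts:
    total += count
    if total > 100:
        break
    max_stream_id = stream_id  # highest stream_id still within budget

assert max_stream_id == 102  # fetch rows with stream_id <= 102 (75 rows)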
@@ -64,7 +64,7 @@ class StateDeltasStore(SQLBaseStore): def get_current_state_deltas_txn(txn): # First we calculate the max stream id that will give us less than # N results. - # We arbitarily limit to 100 stream_id entries to ensure we don't + # We arbitrarily limit to 100 stream_id entries to ensure we don't # select toooo many. sql = """ SELECT stream_id, count(*) @@ -81,7 +81,7 @@ class StateDeltasStore(SQLBaseStore): for stream_id, count in txn: total += count if total > 100: - # We arbitarily limit to 100 entries to ensure we don't + # We arbitrarily limit to 100 entries to ensure we don't # select toooo many. logger.debug( "Clipping current_state_delta_stream rows to stream_id %i", diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d421d18f8d..38adecc78a 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py
@@ -42,6 +42,7 @@ ABSOLUTE_STATS_FIELDS = { "current_state_events", "joined_members", "invited_members", + "knocked_members", "left_members", "banned_members", "local_users_in_room", @@ -1001,7 +1002,9 @@ class StatsStore(StateDeltasStore): ORDER BY {order_by_column} {order} LIMIT ? OFFSET ? """.format( - sql_base=sql_base, order_by_column=order_by_column, order=order, + sql_base=sql_base, + order_by_column=order_by_column, + order=order, ) args += [limit, start] diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e3b9ff5ca6..91f8abb67d 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py
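The stream.py hunks below mostly reflow arguments, but the re-indented get_all_new_events_stream docstring pins down a pagination contract worth spelling out: callers feed next_id back in as from_id until it stops advancing. Under that contract a consumer looks roughly like this (handle_event is a hypothetical stand-in):

async def drain_new_events(store, from_id, current_id, limit=100):
    while True:
        next_id, events = await store.get_all_new_events_stream(
            from_id, current_id, limit
        )
        for event in events:
            await handle_event(event)  # hypothetical consumer
        if next_id == from_id:
            break  # no progress: nothing left in (from_id, current_id]
        from_id = next_id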
@@ -565,7 +565,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): AND e.stream_ordering > ? AND e.stream_ordering <= ? ORDER BY e.stream_ordering ASC """ - txn.execute(sql, (user_id, min_from_id, max_to_id,)) + txn.execute( + sql, + ( + user_id, + min_from_id, + max_to_id, + ), + ) rows = [ _EventDictReturn(event_id, None, stream_ordering) @@ -695,7 +702,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): return "t%d-%d" % (topo, token) def get_stream_id_for_event_txn( - self, txn: LoggingTransaction, event_id: str, allow_none=False, + self, + txn: LoggingTransaction, + event_id: str, + allow_none=False, ) -> int: return self.db_pool.simple_select_one_onecol_txn( txn=txn, @@ -706,8 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): ) async def get_position_for_event(self, event_id: str) -> PersistedEventPosition: - """Get the persisted position for an event - """ + """Get the persisted position for an event""" row = await self.db_pool.simple_select_one( table="events", keyvalues={"event_id": event_id}, @@ -897,19 +906,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): ) -> Tuple[int, List[EventBase]]: """Get all new events - Returns all events with from_id < stream_ordering <= current_id. + Returns all events with from_id < stream_ordering <= current_id. - Args: - from_id: the stream_ordering of the last event we processed - current_id: the stream_ordering of the most recently processed event - limit: the maximum number of events to return + Args: + from_id: the stream_ordering of the last event we processed + current_id: the stream_ordering of the most recently processed event + limit: the maximum number of events to return - Returns: - A tuple of (next_id, events), where `next_id` is the next value to - pass as `from_id` (it will either be the stream_ordering of the - last returned event, or, if fewer than `limit` events were found, - the `current_id`). - """ + Returns: + A tuple of (next_id, events), where `next_id` is the next value to + pass as `from_id` (it will either be the stream_ordering of the + last returned event, or, if fewer than `limit` events were found, + the `current_id`). + """ def get_all_new_events_stream_txn(txn): sql = ( @@ -1238,8 +1247,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta): @cached() async def get_id_for_instance(self, instance_name: str) -> int: - """Get a unique, immutable ID that corresponds to the given Synapse worker instance. - """ + """Get a unique, immutable ID that corresponds to the given Synapse worker instance.""" def _get_id_for_instance_txn(txn): instance_id = self.db_pool.simple_select_one_onecol_txn( diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index cea595ff19..b921d63d30 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
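The transactions.py hunks below reflow the federation catch-up helpers: after an outage, the sender asks get_catch_up_room_event_ids for up to 50 event IDs newer than the destination's last successful stream ordering and retransmits them. A rough outline of such a consumer, with the sending and checkpointing steps stubbed out as hypothetical helpers:

async def catch_up_destination(store, destination, last_successful_stream_ordering):
    while True:
        event_ids = await store.get_catch_up_room_event_ids(
            destination, last_successful_stream_ordering
        )
        if not event_ids:
            return  # fully caught up
        for event_id in event_ids:
            await send_event_to_destination(destination, event_id)  # hypothetical sender
        # Advance the checkpoint to the newest stream ordering just sent;
        # how that value is obtained is elided here.
        last_successful_stream_ordering = await newest_stream_ordering_of(
            store, event_ids
        )  # hypothetical helper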
@@ -64,8 +64,7 @@ class TransactionWorkerStore(SQLBaseStore): class TransactionStore(TransactionWorkerStore): - """A collection of queries for handling PDUs. - """ + """A collection of queries for handling PDUs.""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -198,7 +197,7 @@ class TransactionStore(TransactionWorkerStore): retry_interval: int, ) -> None: """Sets the current retry timings for a given destination. - Both timings should be zero if retrying is no longer occuring. + Both timings should be zero if retrying is no longer occurring. Args: destination @@ -299,7 +298,10 @@ class TransactionStore(TransactionWorkerStore): ) async def store_destination_rooms_entries( - self, destinations: Iterable[str], room_id: str, stream_ordering: int, + self, + destinations: Iterable[str], + room_id: str, + stream_ordering: int, ) -> None: """ Updates or creates `destination_rooms` entries in batch for a single event. @@ -394,7 +396,9 @@ class TransactionStore(TransactionWorkerStore): ) async def get_catch_up_room_event_ids( - self, destination: str, last_successful_stream_ordering: int, + self, + destination: str, + last_successful_stream_ordering: int, ) -> List[str]: """ Returns at most 50 event IDs and their corresponding stream_orderings @@ -418,7 +422,9 @@ class TransactionStore(TransactionWorkerStore): @staticmethod def _get_catch_up_room_event_ids_txn( - txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int, + txn: LoggingTransaction, + destination: str, + last_successful_stream_ordering: int, ) -> List[str]: q = """ SELECT event_id FROM destination_rooms @@ -429,7 +435,8 @@ class TransactionStore(TransactionWorkerStore): LIMIT 50 """ txn.execute( - q, (destination, last_successful_stream_ordering), + q, + (destination, last_successful_stream_ordering), ) event_ids = [row[0] for row in txn] return event_ids diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 79b7ece330..5473ec1485 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py
@@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore): """ async def create_ui_auth_session( - self, clientdict: JsonDict, uri: str, method: str, description: str, + self, + clientdict: JsonDict, + uri: str, + method: str, + description: str, ) -> UIAuthSessionData: """ Creates a new user interactive authentication session. @@ -123,7 +127,10 @@ class UIAuthWorkerStore(SQLBaseStore): return UIAuthSessionData(session_id, **result) async def mark_ui_auth_stage_complete( - self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + self, + session_id: str, + stage_type: str, + result: Union[str, bool, JsonDict], ): """ Mark a session stage as completed. @@ -261,10 +268,12 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) async def add_user_agent_ip_to_ui_auth_session( - self, session_id: str, user_agent: str, ip: str, + self, + session_id: str, + user_agent: str, + ip: str, ): - """Add the given user agent / IP to the tracking table - """ + """Add the given user agent / IP to the tracking table""" await self.db_pool.simple_upsert( table="ui_auth_sessions_ips", keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, @@ -273,7 +282,8 @@ class UIAuthWorkerStore(SQLBaseStore): ) async def get_user_agents_ips_to_ui_auth_session( - self, session_id: str, + self, + session_id: str, ) -> List[Tuple[str, str]]: """Get the given user agents / IPs used during the ui auth process diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 7b9729da09..02ee15676c 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -336,8 +336,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return len(users_to_work_on) async def is_room_world_readable_or_publicly_joinable(self, room_id): - """Check if the room is either world_readable or publically joinable - """ + """Check if the room is either world_readable or publically joinable""" # Create a state filter that only queries join and history state event types_to_filter = ( @@ -516,8 +515,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): ) async def delete_all_from_user_dir(self) -> None: - """Delete the entire user directory - """ + """Delete the entire user directory""" def _delete_all_from_user_dir_txn(txn): txn.execute("DELETE FROM user_directory") @@ -558,6 +556,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self._prefer_local_users_in_search = ( + hs.config.user_directory_search_prefer_local_users + ) + self._server_name = hs.config.server_name + async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): self.db_pool.simple_delete_txn( @@ -709,7 +712,13 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return {row["room_id"] for row in rows} - async def get_user_directory_stream_pos(self) -> int: + async def get_user_directory_stream_pos(self) -> Optional[int]: + """ + Get the stream ID of the user directory stream. + + Returns: + The stream token or None if the initial background update hasn't happened yet. + """ return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos", keyvalues={}, @@ -750,9 +759,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) """ + # We allow manipulating the ranking algorithm by injecting statements + # based on config options. + additional_ordering_statements = [] + ordering_arguments = () + if isinstance(self.database_engine, PostgresEngine): full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + # If enabled, this config option will rank local users higher than those on + # remote instances. + if self._prefer_local_users_in_search: + # This statement checks whether a given user's user ID contains a server name + # that matches the local server + statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)" + additional_ordering_statements.append(statement) + + ordering_arguments += ("%:" + self._server_name,) + # We order by rank and then if they have profile info # The ranking algorithm is hand tweaked for "best" results. Broadly # the idea is we give a higher weight to exact matches. @@ -763,7 +787,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) WHERE - %s + %(where_clause)s AND vector @@ to_tsquery('simple', ?) ORDER BY (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) @@ -783,33 +807,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): 8 ) ) + %(order_case_statements)s DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? 
- """ % ( - where_clause, + """ % { + "where_clause": where_clause, + "order_case_statements": " ".join(additional_ordering_statements), + } + args = ( + join_args + + (full_query, exact_query, prefix_query) + + ordering_arguments + + (limit + 1,) ) - args = join_args + (full_query, exact_query, prefix_query, limit + 1) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) + # If enabled, this config option will rank local users higher than those on + # remote instances. + if self._prefer_local_users_in_search: + # This statement checks whether a given user's user ID contains a server name + # that matches the local server + # + # Note that we need to include a comma at the end for valid SQL + statement = "user_id LIKE ? DESC," + additional_ordering_statements.append(statement) + + ordering_arguments += ("%:" + self._server_name,) + sql = """ SELECT d.user_id AS user_id, display_name, avatar_url FROM user_directory_search as t INNER JOIN user_directory AS d USING (user_id) WHERE - %s + %(where_clause)s AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, + %(order_statements)s display_name IS NULL, avatar_url IS NULL LIMIT ? - """ % ( - where_clause, - ) - args = join_args + (search_query, limit + 1) + """ % { + "where_clause": where_clause, + "order_statements": " ".join(additional_ordering_statements), + } + args = join_args + (search_query,) + ordering_arguments + (limit + 1,) else: # This should be unreachable. raise Exception("Unrecognized database engine") diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index acb24e33af..1fd333b707 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py
@@ -27,7 +27,7 @@ MAX_STATE_DELTA_HOPS = 100 class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud + """Defines functions related to state groups needed to run the state background updates. """ diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 89cdc84a9c..b16b9905d8 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py
@@ -48,8 +48,7 @@ class _GetStateGroupDelta( class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): - """A data store for fetching/storing state groups. - """ + """A data store for fetching/storing state groups.""" def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) @@ -89,7 +88,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): 50000, ) self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", 500000, + "*stateGroupMembersCache*", + 500000, ) def get_max_state_group_txn(txn: Cursor): diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 035f9ea6e9..d15ccfacde 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import platform from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine @@ -28,11 +27,8 @@ def create_engine(database_config) -> BaseDatabaseEngine: return Sqlite3Engine(sqlite3, database_config) if name == "psycopg2": - # pypy requires psycopg2cffi rather than psycopg2 - if platform.python_implementation() == "PyPy": - import psycopg2cffi as psycopg2 # type: ignore - else: - import psycopg2 # type: ignore + # Note that psycopg2cffi-compat provides the psycopg2 module on pypy. + import psycopg2 # type: ignore return PostgresEngine(psycopg2, database_config) diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
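For context on why the PyPy branch could be deleted: the psycopg2cffi-compat shim registers psycopg2cffi under the module name psycopg2, so the substitution happens at packaging level instead of in code. What the shim arranges is roughly this (assuming a PyPy environment with psycopg2cffi installed):

    from psycopg2cffi import compat

    compat.register()  # installs psycopg2cffi in sys.modules as "psycopg2"

    import psycopg2  # now resolves to psycopg2cffi; unmodified code keeps working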
index d6d632dc10..cca839c70f 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py
@@ -94,14 +94,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): @property @abc.abstractmethod def server_version(self) -> str: - """Gets a string giving the server version. For example: '3.22.0' - """ + """Gets a string giving the server version. For example: '3.22.0'""" ... @abc.abstractmethod def in_transaction(self, conn: Connection) -> bool: - """Whether the connection is currently in a transaction. - """ + """Whether the connection is currently in a transaction.""" ... @abc.abstractmethod diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 7719ac32f7..80a3558aec 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py
@@ -138,8 +138,7 @@ class PostgresEngine(BaseDatabaseEngine): @property def supports_using_any_list(self): - """Do we support using `a = ANY(?)` and passing a list - """ + """Do we support using `a = ANY(?)` and passing a list""" return True def is_deadlock(self, error): diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 5db0f0b520..b87e7798da 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py
@@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform import struct import threading import typing @@ -28,7 +29,15 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): super().__init__(database_module, database_config) database = database_config.get("args", {}).get("database") - self._is_in_memory = database in (None, ":memory:",) + self._is_in_memory = database in ( + None, + ":memory:", + ) + + if platform.python_implementation() == "PyPy": + # pypy's sqlite3 module doesn't handle bytearrays, convert them + # back to bytes. + database_module.register_adapter(bytearray, lambda array: bytes(array)) # The current max state_group, or None if we haven't looked # in the DB yet. @@ -57,8 +66,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): @property def supports_using_any_list(self): - """Do we support using `a = ANY(?)` and passing a list - """ + """Do we support using `a = ANY(?)` and passing a list""" return False def check_database(self, db_conn, allow_outdated_version: bool = False): diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
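The register_adapter call above can be exercised in isolation. On PyPy the sqlite3 driver cannot bind bytearray parameters directly, so they are converted to bytes first; on CPython the driver already accepts them and the adapter is harmless. A standalone sketch:

    import sqlite3

    # Teach the driver to bind bytearray parameters as plain bytes.
    sqlite3.register_adapter(bytearray, lambda array: bytes(array))

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE blobs (data BLOB)")
    conn.execute("INSERT INTO blobs VALUES (?)", (bytearray(b"\x00\x01"),))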
index 61fc49c69c..3a0d6fb32e 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py
@@ -411,8 +411,8 @@ class EventsPersistenceStorage: ) for room_id, ev_ctx_rm in events_by_room.items(): - latest_event_ids = await self.main_store.get_latest_event_ids_in_room( - room_id + latest_event_ids = ( + await self.main_store.get_latest_event_ids_in_room(room_id) ) new_latest_event_ids = await self._calculate_new_extremities( room_id, ev_ctx_rm, latest_event_ids @@ -889,7 +889,8 @@ class EventsPersistenceStorage: continue logger.debug( - "Not dropping as too new and not in new_senders: %s", new_senders, + "Not dropping as too new and not in new_senders: %s", + new_senders, ) return new_latest_event_ids @@ -1004,7 +1005,10 @@ class EventsPersistenceStorage: remote_event_ids = [ event_id - for (typ, state_key,), event_id in current_state.items() + for ( + typ, + state_key, + ), event_id in current_state.items() if typ == EventTypes.Member and not self.is_mine_id(state_key) ] rows = await self.main_store.get_membership_from_event_ids(remote_event_ids) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 566ea19bae..6c3c2da520 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py
@@ -113,7 +113,7 @@ def prepare_database( # which should be empty. if config is None: raise ValueError( - "config==None in prepare_database, but databse is not empty" + "config==None in prepare_database, but database is not empty" ) # if it's a worker app, refuse to upgrade the database, to avoid multiple @@ -425,7 +425,10 @@ def _upgrade_existing_database( # We don't support using the same file name in the same delta version. raise PrepareDatabaseException( "Found multiple delta files with the same name in v%d: %s" - % (v, duplicates,) + % ( + v, + duplicates, + ) ) # We sort to ensure that we apply the delta files in a consistent @@ -532,7 +535,8 @@ def _apply_module_schema_files( names_and_streams: the names and streams of schemas to be applied """ cur.execute( - "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), + "SELECT file FROM applied_module_schemas WHERE module_name = ?", + (modname,), ) applied_deltas = {d for d, in cur} for (name, stream) in names_and_streams: @@ -619,9 +623,9 @@ def _get_or_create_schema_state( txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() - current_version = int(row[0]) if row else None - if current_version: + if row is not None: + current_version = int(row[0]) txn.execute( "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
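The _get_or_create_schema_state change above is more than a reshuffle: the old guard `if current_version:` conflated "no row" with "version 0", while `if row is not None:` only skips the delta lookup when the table is genuinely empty. The pitfall in miniature (the row value is hypothetical):

    row = (0, True)  # hypothetical schema_version row: version 0, upgraded

    # Old logic: a falsy version 0 would have skipped reading applied deltas.
    current_version = int(row[0]) if row else None
    assert not current_version

    # New logic: only a missing row skips the lookup.
    if row is not None:
        current_version = int(row[0])  # version 0 is handled like any other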
index 6c359c1aae..3c4908865f 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py
@@ -26,15 +26,13 @@ logger = logging.getLogger(__name__) class PurgeEventsStorage: - """High level interface for purging rooms and event history. - """ + """High level interface for purging rooms and event history.""" def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores async def purge_room(self, room_id: str) -> None: - """Deletes all record of a room - """ + """Deletes all record of a room""" state_groups_to_delete = await self.stores.main.purge_room(room_id) await self.stores.state.purge_room_state(room_id, state_groups_to_delete) diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 31ccbf23dc..d179a41884 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py
@@ -340,8 +340,7 @@ class StateFilter: class StateGroupStorage: - """High level interface to fetching state for event. - """ + """High level interface to fetching state for event.""" def __init__(self, hs: "HomeServer", stores: "Databases"): self.stores = stores @@ -400,7 +399,7 @@ class StateGroupStorage: async def get_state_groups( self, room_id: str, event_ids: Iterable[str] ) -> Dict[int, List[EventBase]]: - """ Get the state groups for the given list of event_ids + """Get the state groups for the given list of event_ids Args: room_id: ID of the room for these events. diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba18f..17291c9d5e 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union from typing_extensions import Protocol @@ -20,23 +20,44 @@ from typing_extensions import Protocol Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ +_Parameters = Union[Sequence[Any], Mapping[str, Any]] + class Cursor(Protocol): - def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + def execute(self, sql: str, parameters: _Parameters = ...) -> Any: ... - def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any: ... - def fetchall(self) -> List[Tuple]: + def fetchone(self) -> Optional[Tuple]: + ... + + def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: ... - def fetchone(self) -> Tuple: + def fetchall(self) -> List[Tuple]: ... @property - def description(self) -> Any: - return None + def description( + self, + ) -> Optional[ + Sequence[ + # Note that this is an approximate typing based on sqlite3 and other + # drivers, and may not be entirely accurate. + Tuple[ + str, + Optional[Any], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + ] + ] + ]: + ... @property def rowcount(self) -> int: @@ -59,7 +80,7 @@ class Connection(Protocol): def commit(self) -> None: ... - def rollback(self, *args, **kwargs) -> None: + def rollback(self) -> None: ... def __enter__(self) -> "Connection": diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
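The widened _Parameters type mirrors what DB-API2 drivers actually accept: sequences for qmark/format style and mappings for named style. Both shapes now type-check against Cursor.execute, as a quick demonstration with the stdlib driver shows:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE users (name TEXT, karma INTEGER)")

    cur.execute("INSERT INTO users VALUES (?, ?)", ("alice", 10))  # Sequence
    cur.execute(
        "INSERT INTO users VALUES (:name, :karma)", {"name": "bob", "karma": 3}
    )  # Mapping

    row = cur.execute("SELECT name FROM users").fetchone()
    assert row is not None  # fetchone() is now typed as Optional[Tuple]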
index 71ef5a72dc..d4643c4fdf 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -245,7 +245,7 @@ class MultiWriterIdGenerator: # and b) noting that if we have seen a run of persisted positions # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7). # - # Note: There is no guarentee that the IDs generated by the sequence + # Note: There is no guarantee that the IDs generated by the sequence # will be gapless; gaps can form when e.g. a transaction was rolled # back. This means that sometimes we won't be able to skip forward the # position even though everything has been persisted. However, since @@ -277,7 +277,9 @@ class MultiWriterIdGenerator: self._load_current_ids(db_conn, tables) def _load_current_ids( - self, db_conn, tables: List[Tuple[str, str, str]], + self, + db_conn, + tables: List[Tuple[str, str, str]], ): cur = db_conn.cursor(txn_name="_load_current_ids") @@ -364,7 +366,10 @@ class MultiWriterIdGenerator: rows.sort() with self._lock: - for (instance, stream_id,) in rows: + for ( + instance, + stream_id, + ) in rows: stream_id = self._return_factor * stream_id self._add_persisted_position(stream_id) @@ -418,7 +423,7 @@ class MultiWriterIdGenerator: # bother, as nothing will read it). # # We only do this on the success path so that the persisted current - # position points to a persited row with the correct instance name. + # position points to a persisted row with the correct instance name. if self._writers: txn.call_after( run_as_background_process, @@ -481,8 +486,7 @@ class MultiWriterIdGenerator: return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: - """Returns the position of the given writer. - """ + """Returns the position of the given writer.""" # If we don't have an entry for the given instance name, we assume it's a # new writer. @@ -509,7 +513,7 @@ class MultiWriterIdGenerator: } def advance(self, instance_name: str, new_id: int): - """Advance the postion of the named writer to the given ID, if greater + """Advance the position of the named writer to the given ID, if greater than existing entry. """ @@ -581,8 +585,7 @@ class MultiWriterIdGenerator: break def _update_stream_positions_table_txn(self, txn: Cursor): - """Update the `stream_positions` table with newly persisted position. - """ + """Update the `stream_positions` table with newly persisted position.""" if not self._writers: return @@ -622,8 +625,7 @@ class _AsyncCtxManagerWrapper: @attr.s(slots=True) class _MultiWriterCtxManager: - """Async context manager returned by MultiWriterIdGenerator - """ + """Async context manager returned by MultiWriterIdGenerator""" id_gen = attr.ib(type=MultiWriterIdGenerator) multiple_ids = attr.ib(type=Optional[int], default=None) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
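A toy version of the skip-forward rule described in the corrected comment above (the positions are invented for the example):

    # Positions reported as persisted by the various writers.
    seen = {5, 6, 7, 9}
    pos = 4  # highest position with no gaps before it

    # Skip forward over a gapless run...
    while pos + 1 in seen:
        pos += 1

    # ...but stop at a gap: 8 may still be in flight, or may never arrive
    # at all if its transaction was rolled back.
    assert pos == 7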
index 0ec4dc2918..3ea637b281 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py
@@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator): def get_next_id_txn(self, txn: Cursor) -> int: txn.execute("SELECT nextval(?)", (self._sequence_name,)) - return txn.fetchone()[0] + fetch_res = txn.fetchone() + assert fetch_res is not None + return fetch_res[0] def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: txn.execute( @@ -122,8 +124,7 @@ class PostgresSequenceGenerator(SequenceGenerator): stream_name: Optional[str] = None, positive: bool = True, ): - """See SequenceGenerator.check_consistency for docstring. - """ + """See SequenceGenerator.check_consistency for docstring.""" txn = db_conn.cursor(txn_name="sequence.check_consistency") @@ -147,7 +148,9 @@ class PostgresSequenceGenerator(SequenceGenerator): txn.execute( "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} ) - last_value, is_called = txn.fetchone() + fetch_res = txn.fetchone() + assert fetch_res is not None + last_value, is_called = fetch_res # If we have an associated stream check the stream_positions table. max_in_stream_positions = None diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py new file mode 100644
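The new asserts follow from the synapse/storage/types.py hunk earlier in this diff, which types Cursor.fetchone() as Optional[Tuple]: mypy now requires the None case to be eliminated before subscripting. Since SELECT nextval(?) always produces exactly one row, an assert both documents and enforces the invariant. The shape of the pattern, in isolation (the "?" placeholder is Synapse's portable parameter style):

    from typing import Optional, Tuple

    def next_id(txn, sequence_name: str = "my_sequence") -> int:
        txn.execute("SELECT nextval(?)", (sequence_name,))
        fetch_res: Optional[Tuple] = txn.fetchone()
        assert fetch_res is not None  # nextval() always yields one row
        return fetch_res[0]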
index 0000000000..1453d04571 --- /dev/null +++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py new file mode 100644
index 0000000000..4589e4539b --- /dev/null +++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,971 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import email.utils +import logging +from typing import Dict, List, Optional, Tuple + +from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset +from synapse.api.errors import SynapseError +from synapse.config._base import ConfigError +from synapse.events import EventBase +from synapse.module_api import ModuleApi +from synapse.types import Requester, StateMap, UserID, get_domain_from_id + +logger = logging.getLogger(__name__) + +ACCESS_RULES_TYPE = "im.vector.room.access_rules" + + +class AccessRules: + DIRECT = "direct" + RESTRICTED = "restricted" + UNRESTRICTED = "unrestricted" + + +VALID_ACCESS_RULES = ( + AccessRules.DIRECT, + AccessRules.RESTRICTED, + AccessRules.UNRESTRICTED, +) + +# Rules to which we need to apply the power levels restrictions. +# +# These are all of the rules that neither: +# * forbid users from joining based on a server blacklist (which means that there +# is no need to apply power level restrictions), nor +# * target direct chats (since we allow both users to be room admins in this case). +# +# The power-level restrictions, when they are applied, prevent the following: +# * the default power level for users (users_default) being set to anything other than 0. +# * a non-default power level being assigned to any user which would be forbidden from +# joining a restricted room. +RULES_WITH_RESTRICTED_POWER_LEVELS = (AccessRules.UNRESTRICTED,) + + +class RoomAccessRules(object): + """Implementation of the ThirdPartyEventRules module API that allows federation admins + to define custom rules for specific events and actions. + Implements the custom behaviour for the "im.vector.room.access_rules" state event. + + Takes a config in the format: + + third_party_event_rules: + module: third_party_rules.RoomAccessRules + config: + # List of domains (server names) that can't be invited to rooms if the + # "restricted" rule is set. Defaults to an empty list. + domains_forbidden_when_restricted: [] + + # Identity server to use when checking the HS an email address belongs to + # using the /info endpoint. Required. + id_server: "vector.im" + + # Enable freezing a room when the last room admin leaves. + # Note that the departing admin must be a local user in order for this feature + # to work. + freeze_room_with_no_admin: false + + Don't forget to consider if you can invite users from your own domain. + """ + + def __init__( + self, config: Dict, module_api: ModuleApi, + ): + self.id_server = config["id_server"] + self.module_api = module_api + + self.domains_forbidden_when_restricted = config.get( + "domains_forbidden_when_restricted", [] + ) + + self.freeze_room_with_no_admin = config.get("freeze_room_with_no_admin", False) + + @staticmethod + def parse_config(config: Dict) -> Dict: + """Parses and validates the options specified in the homeserver config. + + Args: + config: The config dict. + + Returns: + The config dict. 
+ + Raises: + ConfigError: If there was an issue with the provided module configuration. + """ + if "id_server" not in config: + raise ConfigError("No IS for event rules TchapEventRules") + + return config + + async def on_create_room( + self, requester: Requester, config: Dict, is_requester_admin: bool, + ) -> bool: + """Implements synapse.events.ThirdPartyEventRules.on_create_room. + + Checks if an im.vector.room.access_rules event is being set during room creation. + If yes, make sure the event is correct. Otherwise, append an event with the + default rule to the initial state. + + Checks if an m.room.power_levels event is being set during room creation. + If yes, make sure the event is allowed. Otherwise, set power_level_content_override + in the config dict to our modified version of the default room power levels. + + Args: + requester: The user who is making the createRoom request. + config: The createRoom config dict provided by the user. + is_requester_admin: Whether the requester is a Synapse admin. + + Returns: + Whether the request is allowed. + + Raises: + SynapseError: If the createRoom config dict is invalid or its contents are blocked. + """ + is_direct = config.get("is_direct") + preset = config.get("preset") + access_rule = None + join_rule = None + + # If there's a rules event in the initial state, check if it complies with the + # spec for im.vector.room.access_rules and deny the request if not. + for event in config.get("initial_state", []): + if event["type"] == ACCESS_RULES_TYPE: + access_rule = event["content"].get("rule") + + # Make sure the event has a valid content. + if access_rule is None: + raise SynapseError(400, "Invalid access rule") + + # Make sure the rule name is valid. + if access_rule not in VALID_ACCESS_RULES: + raise SynapseError(400, "Invalid access rule") + + if (is_direct and access_rule != AccessRules.DIRECT) or ( + access_rule == AccessRules.DIRECT and not is_direct + ): + raise SynapseError(400, "Invalid access rule") + + if event["type"] == EventTypes.JoinRules: + join_rule = event["content"].get("join_rule") + + if access_rule is None: + # If there's no access rules event in the initial state, create one with the + # default setting. + if is_direct: + default_rule = AccessRules.DIRECT + else: + # If the default value for non-direct chat changes, we should make another + # case here for rooms created with either a "public" join_rule or the + # "public_chat" preset to make sure those keep defaulting to "restricted" + default_rule = AccessRules.RESTRICTED + + if not config.get("initial_state"): + config["initial_state"] = [] + + config["initial_state"].append( + { + "type": ACCESS_RULES_TYPE, + "state_key": "", + "content": {"rule": default_rule}, + } + ) + + access_rule = default_rule + + # Check that the preset in use is compatible with the access rule, whether it's + # user-defined or the default. + # + # Direct rooms may not have their join_rules set to JoinRules.PUBLIC. + if ( + join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT + ) and access_rule == AccessRules.DIRECT: + raise SynapseError(400, "Invalid access rule") + + default_power_levels = self._get_default_power_levels( + requester.user.to_string() + ) + + # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed( + config.get("power_level_content_override", {}), + access_rule, + default_power_levels, + ) + if not allowed: + raise SynapseError(400, "Invalid power levels content override") + + custom_user_power_levels = config.get("power_level_content_override") + + # Second loop for events we need to know the current rule to process. + for event in config.get("initial_state", []): + if event["type"] == EventTypes.PowerLevels: + allowed = self._is_power_level_content_allowed( + event["content"], access_rule, default_power_levels + ) + if not allowed: + raise SynapseError(400, "Invalid power levels content") + + custom_user_power_levels = event["content"] + if custom_user_power_levels: + # If the user is using their own power levels, but failed to provide an expected + # key in the power levels content dictionary, fill it in from the defaults instead + for key, value in default_power_levels.items(): + custom_user_power_levels.setdefault(key, value) + else: + # If power levels were not overridden by the user, completely override with the + # defaults instead + config["power_level_content_override"] = default_power_levels + + return True + + # If power levels are not overridden by the user during room creation, the following + # rules are used instead. Changes from Synapse's default power levels are noted. + # + # The same power levels are currently applied regardless of room preset. + @staticmethod + def _get_default_power_levels(user_id: str) -> Dict: + return { + "users": {user_id: 100}, + "users_default": 0, + "events": { + EventTypes.Name: 50, + EventTypes.PowerLevels: 100, + EventTypes.RoomHistoryVisibility: 100, + EventTypes.CanonicalAlias: 50, + EventTypes.RoomAvatar: 50, + EventTypes.Tombstone: 100, + EventTypes.ServerACL: 100, + EventTypes.RoomEncryption: 100, + }, + "events_default": 0, + "state_default": 100, # Admins should be the only ones to perform other tasks + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 50, # All rooms should require mod to invite, even private + } + + async def check_threepid_can_be_invited( + self, medium: str, address: str, state_events: StateMap[EventBase], + ) -> bool: + """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited. + + Check if a threepid can be invited to the room via a 3PID invite given the current + rules and the threepid's address, by retrieving the HS it's mapped to from the + configured identity server, and checking if we can invite users from it. + + Args: + medium: The medium of the threepid. + address: The address of the threepid. + state_events: A dict mapping (event type, state key) to state event. + State events in the room the threepid is being invited to. + + Returns: + Whether the threepid invite is allowed. + """ + rule = self._get_rule_from_state(state_events) + + if medium != "email": + return False + + if rule != AccessRules.RESTRICTED: + # Only "restricted" requires filtering 3PID invites. We don't need to do + # anything for "direct" here, because only "restricted" requires filtering + # based on the HS the address is mapped to. + return True + + parsed_address = email.utils.parseaddr(address)[1] + if parsed_address != address: + # Avoid reproducing the security issue described here: + # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2 + # It's probably not worth it but let's just be overly safe here. + return False + + # Get the HS this address belongs to from the identity server. 
+ res = await self.module_api.http_client.get_json( + "https://%s/_matrix/identity/api/v1/info" % (self.id_server,), + {"medium": medium, "address": address}, + ) + + # Deny if the address maps to no homeserver, or to a forbidden one. + if not res.get("hs"): + return False + if res.get("hs") in self.domains_forbidden_when_restricted: + return False + + return True + + async def check_event_allowed( + self, event: EventBase, state_events: StateMap[EventBase], + ) -> bool: + """Implements synapse.events.ThirdPartyEventRules.check_event_allowed. + + Checks the event's type and the current rule and calls the right function to + determine whether the event can be allowed. + + Args: + event: The event to check. + state_events: A dict mapping (event type, state key) to state event. + State events in the room the event originated from. + + Returns: + True if the event can be allowed, False otherwise. + """ + if event.type == ACCESS_RULES_TYPE: + return await self._on_rules_change(event, state_events) + + # We need to know the rule to apply when processing the event types below. + rule = self._get_rule_from_state(state_events) + + if event.type == EventTypes.PowerLevels: + return self._is_power_level_content_allowed( + event.content, rule, on_room_creation=False + ) + + if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite: + return await self._on_membership_or_invite(event, rule, state_events) + + if event.type == EventTypes.JoinRules: + return self._on_join_rule_change(event, rule) + + if event.type == EventTypes.RoomAvatar: + return self._on_room_avatar_change(event, rule) + + if event.type == EventTypes.Name: + return self._on_room_name_change(event, rule) + + if event.type == EventTypes.Topic: + return self._on_room_topic_change(event, rule) + + return True + + async def check_visibility_can_be_modified( + self, room_id: str, state_events: StateMap[EventBase], new_visibility: str + ) -> bool: + """Implements + synapse.events.ThirdPartyEventRules.check_visibility_can_be_modified + + Determines whether a room can be published to, or removed from, the public room + list. A room is published if its visibility is set to "public". Otherwise, + its visibility is "private". A room with an access rule other than "restricted" + may not be published. + + Args: + room_id: The ID of the room. + state_events: A dict mapping (event type, state key) to state event. + State events in the room. + new_visibility: The new visibility state. Either "public" or "private". + + Returns: + Whether the room is allowed to be published to, or removed from, the public + rooms directory. + """ + # We need to know the rule to apply when processing the event types below. + rule = self._get_rule_from_state(state_events) + + # Allow adding a room to the public rooms list only if it is restricted + if new_visibility == "public": + return rule == AccessRules.RESTRICTED + + # By default a room is created as "restricted", meaning it is allowed to be + # published to the public rooms directory. + return True + + async def _on_rules_change( + self, event: EventBase, state_events: StateMap[EventBase] + ): + """Checks whether an im.vector.room.access_rules event is forbidden or allowed. + + Args: + event: The im.vector.room.access_rules event. + state_events: A dict mapping (event type, state key) to state event. + State events in the room before the event was sent. + Returns: + True if the event can be allowed, False otherwise. + """ + new_rule = event.content.get("rule") + + # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES: + return False + + # Make sure we don't apply "direct" if the room has more than two members. + if new_rule == AccessRules.DIRECT: + existing_members, threepid_tokens = self._get_members_and_tokens_from_state( + state_events + ) + + if len(existing_members) > 2 or len(threepid_tokens) > 1: + return False + + if new_rule != AccessRules.RESTRICTED: + # Block this change if this room is currently listed in the public rooms + # directory + if await self.module_api.public_room_list_manager.room_is_in_public_room_list( + event.room_id + ): + return False + + prev_rules_event = state_events.get((ACCESS_RULES_TYPE, "")) + + # Now that we know the new rule doesn't break the "direct" case, we can allow any + # new rule in rooms that had none before. + if prev_rules_event is None: + return True + + prev_rule = prev_rules_event.content.get("rule") + + # Currently, we can only go from "restricted" to "unrestricted". + return ( + prev_rule == AccessRules.RESTRICTED and new_rule == AccessRules.UNRESTRICTED + ) + + async def _on_membership_or_invite( + self, event: EventBase, rule: str, state_events: StateMap[EventBase], + ) -> bool: + """Applies the correct rule for incoming m.room.member and + m.room.third_party_invite events. + + Args: + event: The event to check. + rule: The name of the rule to apply. + state_events: A dict mapping (event type, state key) to state event. + The state of the room before the event was sent. + + Returns: + True if the event can be allowed, False otherwise. + """ + if rule == AccessRules.RESTRICTED: + ret = self._on_membership_or_invite_restricted(event) + elif rule == AccessRules.UNRESTRICTED: + ret = self._on_membership_or_invite_unrestricted(event, state_events) + elif rule == AccessRules.DIRECT: + ret = self._on_membership_or_invite_direct(event, state_events) + else: + # We currently apply the default (restricted) if we don't know the rule, we + # might want to change that in the future. + ret = self._on_membership_or_invite_restricted(event) + + if event.type == "m.room.member": + # If this is an admin leaving, and they are the last admin in the room, + # raise the power levels of the room so that the room is 'frozen'. + # + # We have to freeze the room by puppeting an admin user, which we can + # only do for local users + if ( + self.freeze_room_with_no_admin + and self._is_local_user(event.sender) + and event.membership == Membership.LEAVE + ): + await self._freeze_room_if_last_admin_is_leaving(event, state_events) + + return ret + + async def _freeze_room_if_last_admin_is_leaving( + self, event: EventBase, state_events: StateMap[EventBase] + ): + power_level_state_event = state_events.get( + (EventTypes.PowerLevels, "") + ) # type: EventBase + if not power_level_state_event: + return + power_level_content = power_level_state_event.content + + # Do some validation checks on the power level state event + if ( + not isinstance(power_level_content, dict) + or "users" not in power_level_content + or not isinstance(power_level_content["users"], dict) + ): + # We can't use this power level event to determine whether the room should be + # frozen. Bail out. 
+ return + + user_id = event.get("sender") + if not user_id: + return + + # Get every admin user defined in the room's state + admin_users = { + user + for user, power_level in power_level_content["users"].items() + if power_level >= 100 + } + + if user_id not in admin_users: + # This user is not an admin, ignore them + return + + if any( + event_type == EventTypes.Member + and event.membership in [Membership.JOIN, Membership.INVITE] + and state_key in admin_users + and state_key != user_id + for (event_type, state_key), event in state_events.items() + ): + # There's another admin user in, or invited to, the room + return + + # Freeze the room by raising the required power level to send events to 100 + logger.info("Freezing room '%s'", event.room_id) + + # Modify the existing power levels to raise all required types to 100 + # + # This changes a power level state event's content from something like: + # { + # "redact": 50, + # "state_default": 50, + # "ban": 50, + # "notifications": { + # "room": 50 + # }, + # "events": { + # "m.room.avatar": 50, + # "m.room.encryption": 50, + # "m.room.canonical_alias": 50, + # "m.room.name": 50, + # "im.vector.modular.widgets": 50, + # "m.room.topic": 50, + # "m.room.tombstone": 50, + # "m.room.history_visibility": 100, + # "m.room.power_levels": 100 + # }, + # "users_default": 0, + # "events_default": 0, + # "users": { + # "@admin:example.com": 100, + # }, + # "kick": 50, + # "invite": 0 + # } + # + # to + # + # { + # "redact": 100, + # "state_default": 100, + # "ban": 100, + # "notifications": { + # "room": 50 + # }, + # "events": {} + # "users_default": 0, + # "events_default": 100, + # "users": { + # "@admin:example.com": 100, + # }, + # "kick": 100, + # "invite": 100 + # } + new_content = {} + for key, value in power_level_content.items(): + # Do not change "users_default", as that key specifies the default power + # level of new users + if isinstance(value, int) and key != "users_default": + value = 100 + new_content[key] = value + + # Set some values in case they are missing from the original + # power levels event content + new_content.update( + { + # Clear out any special-cased event keys + "events": {}, + # Ensure state_default and events_default keys exist and are 100. + # Otherwise a lower PL user could potentially send state events that + # aren't explicitly mentioned elsewhere in the power level dict + "state_default": 100, + "events_default": 100, + # Membership events default to 50 if they aren't present. Set them + # to 100 here, as they would be set to 100 if they were present anyways + "ban": 100, + "kick": 100, + "invite": 100, + "redact": 100, + } + ) + + await self.module_api.create_and_send_event_into_room( + { + "room_id": event.room_id, + "sender": user_id, + "type": EventTypes.PowerLevels, + "content": new_content, + "state_key": "", + } + ) + + def _on_membership_or_invite_restricted(self, event: EventBase) -> bool: + """Implements the checks and behaviour specified for the "restricted" rule. + + "restricted" currently means that users can only invite users if their server is + included in a limited list of domains. + + Args: + event: The event to check. + + Returns: + True if the event can be allowed, False otherwise. + """ + # We're not applying the rules on m.room.third_party_member events here because + # the filtering on threepids is done in check_threepid_can_be_invited, which is + # called before check_event_allowed. 
+ if event.type == EventTypes.ThirdPartyInvite: + return True + + # We only need to process "join" and "invite" memberships, in order to be backward + # compatible, e.g. if a user from a blacklisted server joined a restricted room + # before the rules started being enforced on the server, that user must be able to + # leave it. + if event.membership not in [Membership.JOIN, Membership.INVITE]: + return True + + invitee_domain = get_domain_from_id(event.state_key) + return invitee_domain not in self.domains_forbidden_when_restricted + + def _on_membership_or_invite_unrestricted( + self, event: EventBase, state_events: StateMap[EventBase] + ) -> bool: + """Implements the checks and behaviour specified for the "unrestricted" rule. + + "unrestricted" currently means that forbidden users cannot join without an invite. + + Returns: + True if the event can be allowed, False otherwise. + """ + # If this is a join from a forbidden user and they don't have an invite to the + # room, then deny it + if event.type == EventTypes.Member and event.membership == Membership.JOIN: + # Check if this user is from a forbidden server + target_domain = get_domain_from_id(event.state_key) + if target_domain in self.domains_forbidden_when_restricted: + # If so, they'll need an invite to join this room. Check if one exists + if not self._user_is_invited_to_room(event.state_key, state_events): + return False + + return True + + def _on_membership_or_invite_direct( + self, event: EventBase, state_events: StateMap[EventBase], + ) -> bool: + """Implements the checks and behaviour specified for the "direct" rule. + + "direct" currently means that no member is allowed apart from the two initial + members the room was created for (i.e. the room's creator and their first invitee). + + Args: + event: The event to check. + state_events: A dict mapping (event type, state key) to state event. + The state of the room before the event was sent. + + Returns: + True if the event can be allowed, False otherwise. + """ + # Get the room memberships and 3PID invite tokens from the room's state. + existing_members, threepid_tokens = self._get_members_and_tokens_from_state( + state_events + ) + + # There should never be more than one 3PID invite in the room state: if the second + # original user came and left, and we're inviting them using their email address, + # given we know they have a Matrix account bound to the address (so they could + # join the first time), Synapse will successfully look it up before attempting to + # store an invite on the IS. + if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite: + # If we already have a 3PID invite in flight, don't accept another one, unless + # the new one has the same invite token as its state key. This is because 3PID + # invite revocations must be allowed, and a revocation is basically a new 3PID + # invite event with an empty content and the same token as the invite it + # revokes. + return event.state_key in threepid_tokens + + if len(existing_members) == 2: + # If the user was one of the two initial users of the room, Synapse would have + # looked it up successfully and thus sent an m.room.member here instead of + # m.room.third_party_invite. + if event.type == EventTypes.ThirdPartyInvite: + return False + + # We can only have m.room.member events here. The rule in this case is to only + # allow the event if its target is one of the initial two members in the room, + # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members + + # We're alone in the room (and always have been) and there's one 3PID invite in + # flight. + if len(existing_members) == 1 and len(threepid_tokens) == 1: + # We can only have m.room.member events here. In this case, we can only allow + # the event if it's either an m.room.member from the joined user (we can assume + # that the only m.room.member event is a join, otherwise we wouldn't be able to + # send an event to the room) or an invite event whose target is the invited + # user. + target = event.state_key + is_from_threepid_invite = self._is_invite_from_threepid( + event, threepid_tokens[0] + ) + return is_from_threepid_invite or target == existing_members[0] + + return True + + def _is_power_level_content_allowed( + self, + content: Dict, + access_rule: str, + default_power_levels: Optional[Dict] = None, + on_room_creation: bool = True, + ) -> bool: + """Check if a given power levels event is permitted under the given access rule. + + It shouldn't be allowed if it either changes the default PL to a non-0 value or + gives a non-0 PL to a user that would have been forbidden from joining the room + under a more restrictive access rule. + + Args: + content: The content of the m.room.power_levels event to check. + access_rule: The access rule in place in this room. + default_power_levels: The default power levels when a room is created with + the specified access rule. Required if on_room_creation is True. + on_room_creation: True if this call is happening during a room's + creation, False otherwise. + + Returns: + Whether the content of the power levels event is valid. + """ + # Only enforce these rules during room creation + # + # We want to allow admins to modify or fix the power levels in a room if they + # have a special circumstance, but still want to encourage a certain pattern during + # room creation. + if on_room_creation: + # We specifically don't fail if "invite" or "state_default" are None, as those + # values should be replaced with our "default" power level values anyway, + # which are compliant + + invite = default_power_levels["invite"] + state_default = default_power_levels["state_default"] + + # If invite requirements are less than our required defaults + if content.get("invite", invite) < invite: + return False + + # If "other" state requirements are less than our required defaults + if content.get("state_default", state_default) < state_default: + return False + + # Check if we need to apply the restrictions with the current rule. + if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS: + return True + + # If users_default is explicitly set to a non-0 value, deny the event. + users_default = content.get("users_default", 0) + if users_default: + return False + + users = content.get("users", {}) + for user_id, power_level in users.items(): + server_name = get_domain_from_id(user_id) + # Check the domain against the blacklist. If found, and the PL isn't 0, deny + # the event. + if ( + server_name in self.domains_forbidden_when_restricted + and power_level != 0 + ): + return False + + return True + + def _on_join_rule_change(self, event: EventBase, rule: str) -> bool: + """Check whether a join rule change is allowed. + + A join rule change is always allowed unless the new join rule is "public" and + the current access rule is "direct". + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + Whether the change is allowed.
""" + if event.content.get("join_rule") == JoinRules.PUBLIC: + return rule != AccessRules.DIRECT + + return True + + def _on_room_avatar_change(self, event: EventBase, rule: str) -> bool: + """Check whether a change of room avatar is allowed. + The current rule is to forbid such a change in direct chats but allow it + everywhere else. + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + True if the event can be allowed, False otherwise. + """ + return rule != AccessRules.DIRECT + + def _on_room_name_change(self, event: EventBase, rule: str) -> bool: + """Check whether a change of room name is allowed. + The current rule is to forbid such a change in direct chats but allow it + everywhere else. + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + True if the event can be allowed, False otherwise. + """ + return rule != AccessRules.DIRECT + + def _on_room_topic_change(self, event: EventBase, rule: str) -> bool: + """Check whether a change of room topic is allowed. + The current rule is to forbid such a change in direct chats but allow it + everywhere else. + + Args: + event: The event to check. + rule: The name of the rule to apply. + + Returns: + True if the event can be allowed, False otherwise. + """ + return rule != AccessRules.DIRECT + + @staticmethod + def _get_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]: + """Extract the rule to be applied from the given set of state events. + + Args: + state_events: A dict mapping (event type, state key) to state event. + + Returns: + The name of the rule (either "direct", "restricted" or "unrestricted") if found, + else "restricted", the implicit default when the room has no access rules event. + """ + access_rules = state_events.get((ACCESS_RULES_TYPE, "")) + if access_rules is None: + return AccessRules.RESTRICTED + + return access_rules.content.get("rule") + + @staticmethod + def _get_join_rule_from_state(state_events: StateMap[EventBase]) -> Optional[str]: + """Extract the room's join rule from the given set of state events. + + Args: + state_events (dict[tuple[event type, state key], EventBase]): The set of state + events. + + Returns: + The name of the join rule (either "public", or "invite") if found, else None. + """ + join_rule_event = state_events.get((EventTypes.JoinRules, "")) + if join_rule_event is None: + return None + + return join_rule_event.content.get("join_rule") + + @staticmethod + def _get_members_and_tokens_from_state( + state_events: StateMap[EventBase], + ) -> Tuple[List[str], List[str]]: + """Retrieves the list of users that have an m.room.member event in the room, + as well as the 3PID invite tokens in the room. + + Args: + state_events: A dict mapping (event type, state key) to state event. + + Returns: + A tuple containing the: + * targets of the m.room.member events in the state. + * 3PID invite tokens in the state. + """ + existing_members = [] + threepid_invite_tokens = [] + for key, state_event in state_events.items(): + if key[0] == EventTypes.Member and state_event.content: + existing_members.append(state_event.state_key) + if key[0] == EventTypes.ThirdPartyInvite and state_event.content: + # Don't include revoked invites. + threepid_invite_tokens.append(state_event.state_key) + + return existing_members, threepid_invite_tokens + + @staticmethod + def _is_invite_from_threepid(invite: EventBase, threepid_invite_token: str) -> bool: + """Checks whether the given invite follows from the given 3PID invite. + + Args: + invite: The m.room.member event with "invite" membership.
+ threepid_invite_token: The state key from the 3PID invite. + + Returns: + Whether the invite is due to the given 3PID invite. + """ + token = ( + invite.content.get("third_party_invite", {}) + .get("signed", {}) + .get("token", "") + ) + + return token == threepid_invite_token + + def _is_local_user(self, user_id: str) -> bool: + """Checks whether a given user ID belongs to this homeserver, or a remote one + + Args: + user_id: A user ID to check. + + Returns: + True if the user belongs to this homeserver, False otherwise. + """ + user = UserID.from_string(user_id) + + # Extract the localpart and ask the module API for a user ID from the localpart + # The module API will append the local homeserver's server_name + local_user_id = self.module_api.get_qualified_user_id(user.localpart) + + # If the user ID we get based on the localpart is the same as the original user ID, + # then they were a local user + return user_id == local_user_id + + def _user_is_invited_to_room( + self, user_id: str, state_events: StateMap[EventBase] + ) -> bool: + """Checks whether a given user has been invited to a room + + A user has an invite for a room if its state contains an `m.room.member` + event with membership "invite" and their user ID as the state key. + + Args: + user_id: The user to check. + state_events: The state events from the room. + + Returns: + True if the user has been invited to the room, or False if they haven't. + """ + for (event_type, state_key), state_event in state_events.items(): + if ( + event_type == EventTypes.Member + and state_key == user_id + and state_event.membership == Membership.INVITE + ): + return True + + return False diff --git a/synapse/types.py b/synapse/types.py
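Pulling _on_rules_change together: beyond the extra guards ("direct" requires at most two members and one 3PID token, and a room leaving "restricted" must not be in the public directory), the permitted rule transitions reduce to a small table. An illustrative summary, not part of the module:

    # (previous rule, new rule) -> allowed? A previous rule of None means
    # the room had no access rules event yet.
    ALLOWED = {
        (None, "direct"): True,
        (None, "restricted"): True,
        (None, "unrestricted"): True,
        ("restricted", "unrestricted"): True,  # the only change allowed later
    }

    def may_transition(prev: object, new: str) -> bool:
        return ALLOWED.get((prev, new), False)

    assert may_transition("restricted", "unrestricted")
    assert not may_transition("unrestricted", "restricted")
    assert not may_transition("direct", "unrestricted")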
index eafe729dfe..b629976853 100644 --- a/synapse/types.py +++ b/synapse/types.py
@@ -34,6 +34,7 @@ from typing import ( import attr from signedjson.key import decode_verify_key_bytes +from six.moves import filter from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError @@ -335,6 +336,19 @@ def contains_invalid_mxid_characters(localpart: str) -> bool: return any(c not in mxid_localpart_allowed_characters for c in localpart) +def strip_invalid_mxid_characters(localpart): + """Removes any characters not allowed in an mxid localpart + + Args: + localpart (basestring): the localpart to be stripped + + Returns: + localpart (basestring): the localpart with the invalid characters removed + """ + filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart) + return "".join(filtered) + + UPPER_CASE_PATTERN = re.compile(b"[A-Z_]") # the following is a pattern which matches '=', and bytes which are not allowed in a mxid @@ -469,8 +483,7 @@ class RoomStreamToken: ) def __attrs_post_init__(self): - """Validates that both `topological` and `instance_map` aren't set. - """ + """Validates that `topological` and `instance_map` are not both set.""" if self.instance_map and self.topological: raise ValueError( @@ -498,7 +511,11 @@ class RoomStreamToken: instance_name = await store.get_name_from_instance_id(instance_id) instance_map[instance_name] = pos - return cls(topological=None, stream=stream, instance_map=instance_map,) + return cls( + topological=None, + stream=stream, + instance_map=instance_map, + ) except Exception: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -675,7 +692,7 @@ class PersistedEventPosition: persisted in the same room after this position will be after the returned `RoomStreamToken`. - Note: no guarentees are made about ordering w.r.t. events in other + Note: no guarantees are made about ordering w.r.t. events in other rooms. """ # Doing the naive thing satisfies the desired properties described in diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 9a873c8e8e..719e35b78d 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -252,8 +252,7 @@ class Linearizer: self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry] def is_queued(self, key: Hashable) -> bool: - """Checks whether there is a process queued up waiting - """ + """Checks whether there is a process queued up waiting""" entry = self.key_to_defer.get(key) if not entry: # No entry so nothing is waiting. @@ -452,7 +451,9 @@ R = TypeVar("R") def timeout_deferred( - deferred: defer.Deferred, timeout: float, reactor: IReactorTime, + deferred: defer.Deferred, + timeout: float, + reactor: IReactorTime, ) -> defer.Deferred: """The built-in Twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new @@ -497,7 +498,7 @@ def timeout_deferred( delayed_call = reactor.callLater(timeout, time_it_out) def convert_cancelled(value: failure.Failure): - # if the orgininal deferred was cancelled, and our timeout has fired, then + # if the original deferred was cancelled, and our timeout has fired, then # the reason it was cancelled was due to our timeout. Turn the CancelledError # into a TimeoutError. if timed_out[0] and value.check(CancelledError): @@ -529,8 +530,7 @@ def timeout_deferred( @attr.s(slots=True, frozen=True) class DoneAwaitable: - """Simple awaitable that returns the provided value. - """ + """Simple awaitable that returns the provided value.""" value = attr.ib() @@ -545,8 +545,7 @@ class DoneAwaitable: def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: - """Convert a value to an awaitable if not already an awaitable. - """ + """Convert a value to an awaitable if not already an awaitable.""" if inspect.isawaitable(value): assert isinstance(value, Awaitable) return value diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
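A hedged sketch of driving the reformatted timeout_deferred: wrap a deferred that never fires and watch it errback with a TimeoutError. It assumes a Synapse checkout on the import path; everything else is stock Twisted:

from twisted.internet import defer, reactor

from synapse.util.async_helpers import timeout_deferred

never_fires = defer.Deferred()  # nothing ever calls this back
wrapped = timeout_deferred(never_fires, timeout=1.0, reactor=reactor)

def on_timeout(failure):
    # the wrapper converts the internal cancellation into a TimeoutError
    print("timed out as expected:", failure.value)
    reactor.stop()

wrapped.addErrback(on_timeout)
reactor.run()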
index 89f0b38535..e676c2cac4 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py
@@ -149,8 +149,7 @@ KNOWN_KEYS = { def intern_string(string): - """Takes a (potentially) unicode string and interns it if it's ascii - """ + """Takes a (potentially) unicode string and interns it if it's ascii""" if string is None: return None @@ -161,8 +160,7 @@ def intern_string(string): def intern_dict(dictionary): - """Takes a dictionary and interns well known keys and their values - """ + """Takes a dictionary and interns well known keys and their values""" return { KNOWN_KEYS.get(key, key): _intern_known_values(key, value) for key, value in dictionary.items() diff --git a/synapse/util/caches/cached_call.py b/synapse/util/caches/cached_call.py new file mode 100644
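For background on what intern_string buys, a stdlib-only sketch of interning; the strings are arbitrary:

import sys

a = "".join(["ro", "om_id"])  # built at runtime, so not automatically interned
b = "".join(["ro", "om_id"])
print(a == b)   # True: equal values
print(a is b)   # usually False: two separate objects in memory

a, b = sys.intern(a), sys.intern(b)
print(a is b)   # True: both now reference one shared copy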
index 0000000000..3ee0f2317a --- /dev/null +++ b/synapse/util/caches/cached_call.py
@@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union + +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure + +from synapse.logging.context import make_deferred_yieldable, run_in_background + +TV = TypeVar("TV") + + +class CachedCall(Generic[TV]): + """A wrapper for asynchronous calls whose results should be shared + + This is useful for wrapping asynchronous functions, where there might be multiple + callers, but we only want to call the underlying function once (and have the result + returned to all callers). + + Similar results can be achieved via a lock of some form, but that typically requires + more boilerplate (and ends up being less efficient). + + Correctly handles Synapse logcontexts (logs and resource usage for the underlying + function are logged against the logcontext which is active when get() is first + called). + + Example usage: + + _cached_val = CachedCall(_load_prop) + + async def handle_request() -> X: + # We can call this multiple times, but it will result in a single call to + # _load_prop(). + return await _cached_val.get() + + async def _load_prop() -> X: + await difficult_operation() + + + The implementation is deliberately single-shot (ie, once the call is initiated, + there is no way to ask for it to be run again). This keeps the implementation and + semantics simple. If you want to make a new call, simply replace the whole + CachedCall object. + """ + + __slots__ = ["_callable", "_deferred", "_result"] + + def __init__(self, f: Callable[[], Awaitable[TV]]): + """ + Args: + f: The underlying function. Only one call to this function will be alive + at once (per instance of CachedCall) + """ + self._callable = f # type: Optional[Callable[[], Awaitable[TV]]] + self._deferred = None # type: Optional[Deferred] + self._result = None # type: Union[None, Failure, TV] + + async def get(self) -> TV: + """Kick off the call if necessary, and return the result""" + + # Fire off the callable now if this is our first time + if not self._deferred: + self._deferred = run_in_background(self._callable) + + # we will never need the callable again, so make sure it can be GCed + self._callable = None + + # once the deferred completes, store the result. We cannot simply leave the + # result in the deferred, since if it's a Failure, GCing the deferred + # would then log a critical error about unhandled Failures. + def got_result(r): + self._result = r + + self._deferred.addBoth(got_result) + + # TODO: consider cancellation semantics. Currently, if the call to get() + # is cancelled, the underlying call will continue (and any future calls + # will get the result/exception), which I think is *probably* ok, modulo + # the fact the underlying call may be logged to a cancelled logcontext, + # and any eventual exception may not be reported.
+ + # we can now await the deferred, and once it completes, return the result. + await make_deferred_yieldable(self._deferred) + + # I *think* this is the easiest way to correctly raise a Failure without having + # to gut-wrench into the implementation of Deferred. + d = Deferred() + d.callback(self._result) + return await d + + +class RetryOnExceptionCachedCall(Generic[TV]): + """A wrapper around CachedCall which will retry the call if an exception is thrown + + This is used in much the same way as CachedCall, but adds some extra functionality + so that if the underlying function throws an exception, then the next call to get() + will initiate another call to the underlying function. (Any calls to get() which + are already pending will raise the exception.) + """ + + __slots__ = ["_cachedcall"] + + def __init__(self, f: Callable[[], Awaitable[TV]]): + async def _wrapper() -> TV: + try: + return await f() + except Exception: + # the call raised an exception: replace the underlying CachedCall to + # trigger another call next time get() is called + self._cachedcall = CachedCall(_wrapper) + raise + + self._cachedcall = CachedCall(_wrapper) + + async def get(self) -> TV: + return await self._cachedcall.get() diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
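Tying the docstring example for the new CachedCall together, this is roughly how it is meant to be used; _load_prop and the int payload are placeholders, and a real caller also needs a running reactor and logcontext:

from synapse.util.caches.cached_call import CachedCall

async def _load_prop() -> int:
    # expensive work: runs at most once per CachedCall instance
    return 42

_cached_val = CachedCall(_load_prop)

async def handle_request() -> int:
    # every concurrent caller shares the single underlying _load_prop() call
    return await _cached_val.get()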
index a924140cdf..4e84379914 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py
@@ -122,7 +122,8 @@ class _LruCachedFunction(Generic[F]): def lru_cache( - max_entries: int = 1000, cache_context: bool = False, + max_entries: int = 1000, + cache_context: bool = False, ) -> Callable[[F], _LruCachedFunction[F]]: """A method decorator that applies a memoizing cache around the function. @@ -156,7 +157,9 @@ def lru_cache( def func(orig: F) -> _LruCachedFunction[F]: desc = LruCacheDescriptor( - orig, max_entries=max_entries, cache_context=cache_context, + orig, + max_entries=max_entries, + cache_context=cache_context, ) return cast(_LruCachedFunction[F], desc) @@ -170,14 +173,18 @@ class LruCacheDescriptor(_CacheDescriptorBase): sentinel = object() def __init__( - self, orig, max_entries: int = 1000, cache_context: bool = False, + self, + orig, + max_entries: int = 1000, + cache_context: bool = False, ): super().__init__(orig, num_args=None, cache_context=cache_context) self.max_entries = max_entries def __get__(self, obj, owner): cache = LruCache( - cache_name=self.orig.__name__, max_size=self.max_entries, + cache_name=self.orig.__name__, + max_size=self.max_entries, ) # type: LruCache[CacheKey, Any] get_cache_key = self.cache_key_builder @@ -212,7 +219,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): class DeferredCacheDescriptor(_CacheDescriptorBase): - """ A method decorator that applies a memoizing cache around the function. + """A method decorator that applies a memoizing cache around the function. This caches deferreds, rather than the results themselves. Deferreds that fail are removed from the cache. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
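A sketch of the reformatted lru_cache decorator in use; MyHandler and its method body are hypothetical:

from synapse.util.caches.descriptors import lru_cache

class MyHandler:
    @lru_cache(max_entries=100)
    def get_thing(self, key):
        # body only runs on a cache miss; at most 100 results are retained
        print("computing", key)
        return key * 2

handler = MyHandler()
handler.get_thing(21)  # prints "computing 21"
handler.get_thing(21)  # served from the cache, no print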
index c541bf4579..644e9e778a 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py
@@ -84,8 +84,7 @@ class StreamChangeCache: return False def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: - """Returns True if the entity may have been updated since stream_pos - """ + """Returns True if the entity may have been updated since stream_pos""" assert isinstance(stream_pos, int) if stream_pos < self._earliest_known_stream_pos: @@ -133,8 +132,7 @@ class StreamChangeCache: return result def has_any_entity_changed(self, stream_pos: int) -> bool: - """Returns if any entity has changed - """ + """Returns if any entity has changed""" assert type(stream_pos) is int if not self._cache: diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
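An illustrative run against has_entity_changed, using a made-up entity and stream positions (constructor arguments per the existing class: a name and the earliest known stream position):

from synapse.util.caches.stream_change_cache import StreamChangeCache

cache = StreamChangeCache("ExampleCache", 100)  # nothing known before position 100
cache.entity_has_changed("@user:example.com", 105)

print(cache.has_entity_changed("@user:example.com", 102))  # True: changed at 105 > 102
print(cache.has_entity_changed("@user:example.com", 106))  # False: no change after 106
print(cache.has_entity_changed("@user:example.com", 50))   # True: 50 predates the cache, so "maybe"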
index a6ee9edaec..3c47285d05 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py
@@ -108,7 +108,10 @@ class Signal: return await maybe_awaitable(observer(*args, **kwargs)) except Exception as e: logger.warning( - "%s signal observer %s failed: %r", self.name, observer, e, + "%s signal observer %s failed: %r", + self.name, + observer, + e, ) deferreds = [run_in_background(do, o) for o in self.observers] diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 733f5e26e6..68dc632491 100644 --- a/synapse/util/file_consumer.py +++ b/synapse/util/file_consumer.py
@@ -83,15 +83,13 @@ class BackgroundFileConsumer: self._producer.resumeProducing() def unregisterProducer(self): - """Part of IProducer interface - """ + """Part of IConsumer interface""" self._producer = None if not self._finished_deferred.called: self._bytes_queue.put_nowait(None) def write(self, bytes): - """Part of IProducer interface - """ + """Part of IConsumer interface""" if self._write_exception: raise self._write_exception @@ -107,8 +105,7 @@ self._producer.pauseProducing() def _writer(self): - """This is run in a background thread to write to the file. - """ + """This is run in a background thread to write to the file.""" try: while self._producer or not self._bytes_queue.empty(): # If we've paused the producer check if we should resume the @@ -135,13 +132,11 @@ self._file_obj.close() def wait(self): - """Returns a deferred that resolves when finished writing to file - """ + """Returns a deferred that resolves when finished writing to file""" return make_deferred_yieldable(self._finished_deferred) def _resume_paused_producer(self): - """Gets called if we should resume producing after being paused - """ + """Gets called if we should resume producing after being paused""" if self._paused_producer and self._producer: self._paused_producer = False self._producer.resumeProducing() diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index 8d2411513f..98707c119d 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py
@@ -62,7 +62,8 @@ def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]: def sorted_topologically( - nodes: Iterable[T], graph: Mapping[T, Collection[T]], + nodes: Iterable[T], + graph: Mapping[T, Collection[T]], ) -> Generator[T, None, None]: """Given a set of nodes and a graph, yield the nodes in topological order. diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
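The reformatted sorted_topologically signature in action; the graph maps each node to the nodes it depends on, so dependencies are yielded first:

from synapse.util.iterutils import sorted_topologically

# 1 depends on 2, and 2 depends on 3, so the expected order is 3, 2, 1
print(list(sorted_topologically([1, 2, 3], {1: [2], 2: [3]})))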
index 50516926f3..e3a8ed5b2f 100644 --- a/synapse/util/jsonobject.py +++ b/synapse/util/jsonobject.py
@@ -15,7 +15,7 @@ class JsonEncodedObject: - """ A common base class for defining protocol units that are represented + """A common base class for defining protocol units that are represented as JSON. Attributes: @@ -39,7 +39,7 @@ class JsonEncodedObject: """ def __init__(self, **kwargs): - """ Takes the dict of `kwargs` and loads all keys that are *valid* + """Takes the dict of `kwargs` and loads all keys that are *valid* (i.e., are included in the `valid_keys` list) into the dictionary instance variable. @@ -61,7 +61,7 @@ class JsonEncodedObject: self.unrecognized_keys[k] = v def get_dict(self): - """ Converts this protocol unit into a :py:class:`dict`, ready to be + """Converts this protocol unit into a :py:class:`dict`, ready to be encoded as JSON. The keys it encodes are: `valid_keys` - `internal_keys` diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
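A hypothetical subclass showing the contract the docstrings describe: valid keys become attributes, and get_dict() encodes valid_keys minus internal_keys. ExamplePdu and its keys are invented for illustration:

class ExamplePdu(JsonEncodedObject):
    valid_keys = ["event_id", "origin"]
    internal_keys = ["origin"]  # accepted, but excluded from the encoded dict

pdu = ExamplePdu(event_id="$abc", origin="example.com")
print(pdu.event_id)    # "$abc": valid keys are loaded onto the instance
print(pdu.get_dict())  # {"event_id": "$abc"}: origin is internal, so omitted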
index f4de6b9f54..1023c856d1 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py
@@ -161,8 +161,7 @@ class Measure: return self._logging_context.get_resource_usage() def _update_in_flight(self, metrics): - """Gets called when processing in flight metrics - """ + """Gets called when processing in flight metrics""" duration = self.clock.time() - self.start metrics.real_time_max = max(metrics.real_time_max, duration) diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 09b094ded7..d184e2a90c 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py
@@ -25,7 +25,7 @@ from synapse.config._util import json_error_to_config_error def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: - """ Loads a synapse module with its config + """Loads a synapse module with its config Args: provider: a dict with keys 'module' (the module name) and 'config' diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
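A hedged sketch of load_module's contract; the provider dict values are invented, and the 'module' path is assumed to name an importable class that exposes a parse_config method for its 'config' block:

from synapse.util.module_loader import load_module

provider = {
    "module": "my_plugin.ExampleProvider",  # hypothetical module path
    "config": {"endpoint": "https://example.com"},
}

# Returns the class plus its parsed config; config_path only labels errors.
provider_class, parsed_config = load_module(provider, ("example_providers",))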
index 72574d3af2..d9f9ae99d6 100644 --- a/synapse/util/patch_inline_callbacks.py +++ b/synapse/util/patch_inline_callbacks.py
@@ -204,16 +204,13 @@ def _check_yield_points(f: Callable, changes: List[str]): # We don't raise here as it's perfectly valid for contexts to # change in a function, as long as it sets the correct context # on resolving (which is checked separately). - err = ( - "%s changed context from %s to %s, happened between lines %d and %d in %s" - % ( - frame.f_code.co_name, - expected_context, - current_context(), - last_yield_line_no, - frame.f_lineno, - frame.f_code.co_filename, - ) + err = "%s changed context from %s to %s, happened between lines %d and %d in %s" % ( + frame.f_code.co_name, + expected_context, + current_context(), + last_yield_line_no, + frame.f_lineno, + frame.f_code.co_filename, ) changes.append(err) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f8038bf861..9ce7873ab5 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py
@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken -client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$") # https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, # together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically @@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") rand = random.SystemRandom() -def random_string(length): +def random_string(length: int) -> str: return "".join(rand.choice(string.ascii_letters) for _ in range(length)) -def random_string_with_symbols(length): +def random_string_with_symbols(length: int) -> str: return "".join(rand.choice(_string_with_symbols) for _ in range(length)) -def is_ascii(s): - if isinstance(s, bytes): - try: - s.decode("ascii").encode("ascii") - except UnicodeDecodeError: - return False - except UnicodeEncodeError: - return False - return True +def is_ascii(s: bytes) -> bool: + try: + s.decode("ascii").encode("ascii") + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False + return True -def assert_valid_client_secret(client_secret): - """Validate that a given string matches the client_secret regex defined by the spec""" +def assert_valid_client_secret(client_secret: str) -> None: + """Validate that a given string matches the client_secret format defined by the spec""" + if ( + len(client_secret) <= 0 + or len(client_secret) > 255 + or CLIENT_SECRET_REGEX.match(client_secret) is None + ): raise SynapseError( 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM ) diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
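The tightened validation in action: beyond the regex, secrets must now be 1-255 characters. Exercising it (the sample secrets are arbitrary; SynapseError carries an HTTP 400):

from synapse.api.errors import SynapseError
from synapse.util.stringutils import assert_valid_client_secret

assert_valid_client_secret("this_is-a.valid=secret")  # passes silently

for bad in ("", "x" * 256, "white space", "pipe|char"):
    try:
        assert_valid_client_secret(bad)
    except SynapseError as e:
        print("rejected %r: %s" % (bad, e))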
index 43c2e0ac23..63f955acff 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py
@@ -19,8 +19,8 @@ import re logger = logging.getLogger(__name__) -def check_3pid_allowed(hs, medium, address): - """Checks whether a given format of 3PID is allowed to be used on this HS +async def check_3pid_allowed(hs, medium, address): + """Checks whether a given 3PID is allowed to be used on this HS Args: hs (synapse.server.HomeServer): server @@ -31,6 +31,33 @@ def check_3pid_allowed(hs, medium, address): bool: whether the 3PID medium/address is allowed to be added to this HS """ + if hs.config.check_is_for_allowed_local_3pids: + data = await hs.get_simple_http_client().get_json( + "https://%s%s" + % ( + hs.config.check_is_for_allowed_local_3pids, + "/_matrix/identity/api/v1/internal-info", + ), + {"medium": medium, "address": address}, + ) + + # Check for invalid response + if "hs" not in data and "shadow_hs" not in data: + return False + + # Check if this user is intended to register for this homeserver + if ( + data.get("hs") != hs.config.server_name + and data.get("shadow_hs") != hs.config.server_name + ): + return False + + if data.get("requires_invite", False) and not data.get("invited", False): + # Requires an invite but hasn't been invited + return False + + return True + if hs.config.allowed_local_3pids: for constraint in hs.config.allowed_local_3pids: logger.debug( diff --git a/synapse/visibility.py b/synapse/visibility.py
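For reference, the shape of the internal-info response the new check_3pid_allowed code path consumes; every value below is illustrative, not part of the patch:

# A hypothetical response from /_matrix/identity/api/v1/internal-info.
example_internal_info = {
    "hs": "example.com",         # homeserver this 3PID should register on
    "shadow_hs": "example.com",  # alternate homeserver also treated as local
    "requires_invite": True,     # registration is gated on an invite
    "invited": False,            # whether this address has received one
}

# With hs.config.server_name == "example.com", check_3pid_allowed would
# return False for this payload: an invite is required but not present.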
index ec50e7e977..e39d02602a 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -80,6 +80,7 @@ async def filter_events_for_client( events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), @@ -233,7 +234,7 @@ async def filter_events_for_client( elif visibility == HistoryVisibility.SHARED and is_peeking: # if the visibility is shared, users cannot see the event unless - # they have *subequently* joined the room (or were members at the + # they have *subsequently* joined the room (or were members at the # time, of course) # # XXX: if the user has subsequently joined and then left again,