summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/__init__.py1
-rw-r--r--synapse/rest/admin/rooms.py4
-rw-r--r--synapse/rest/admin/users.py20
-rw-r--r--synapse/rest/client/v1/login.py17
-rw-r--r--synapse/rest/client/v1/presence.py8
-rw-r--r--synapse/rest/client/v1/profile.py35
-rw-r--r--synapse/rest/client/v1/room.py6
-rw-r--r--synapse/rest/client/v2_alpha/account.py212
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py7
-rw-r--r--synapse/rest/client/v2_alpha/register.py246
-rw-r--r--synapse/rest/client/v2_alpha/report_event.py10
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py143
-rw-r--r--synapse/rest/client/versions.py5
-rw-r--r--synapse/rest/consent/consent_resource.py5
-rw-r--r--synapse/rest/media/v1/_base.py3
-rw-r--r--synapse/rest/media/v1/media_repository.py12
-rw-r--r--synapse/rest/media/v1/media_storage.py6
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py27
18 files changed, 643 insertions, 124 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py

index 46e458e95b..2e81eeff65 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -118,6 +118,7 @@ class ClientRestResource(JsonResource): room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) + password_policy.register_servlets(hs, client_resource) relations.register_servlets(hs, client_resource) password_policy.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 8173baef8f..e07c32118d 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -15,7 +15,7 @@ import logging from typing import List, Optional -from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.servlet import ( RestServlet, @@ -77,7 +77,7 @@ class ShutdownRoomRestServlet(RestServlet): info, stream_id = await self._room_creation_handler.create_room( room_creator_requester, config={ - "preset": "public_chat", + "preset": RoomCreationPreset.PUBLIC_CHAT, "name": room_name, "power_level_content_override": {"users_default": -10}, }, diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index fefc8f71fa..e4330c39d6 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -16,9 +16,7 @@ import hashlib import hmac import logging import re - -from six import text_type -from six.moves import http_client +from http import HTTPStatus from synapse.api.constants import UserTypes from synapse.api.errors import Codes, NotFoundError, SynapseError @@ -215,10 +213,7 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) if "password" in body: - if ( - not isinstance(body["password"], text_type) - or len(body["password"]) > 512 - ): + if not isinstance(body["password"], str) or len(body["password"]) > 512: raise SynapseError(400, "Invalid password") else: new_password = body["password"] @@ -252,7 +247,7 @@ class UserRestServletV2(RestServlet): password = body.get("password") password_hash = None if password is not None: - if not isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_hash = await self.auth_handler.hash(password) @@ -370,10 +365,7 @@ class UserRegisterServlet(RestServlet): 400, "username must be specified", errcode=Codes.BAD_JSON ) else: - if ( - not isinstance(body["username"], text_type) - or len(body["username"]) > 512 - ): + if not isinstance(body["username"], str) or len(body["username"]) > 512: raise SynapseError(400, "Invalid username") username = body["username"].encode("utf-8") @@ -386,7 +378,7 @@ class UserRegisterServlet(RestServlet): ) else: password = body["password"] - if not isinstance(password, text_type) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") password_bytes = password.encode("utf-8") @@ -477,7 +469,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index dceb2792fa..bf0f9bd077 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -60,10 +60,18 @@ def login_id_thirdparty_from_phone(identifier): Returns: Login identifier dict of type 'm.id.threepid' """ - if "country" not in identifier or "number" not in identifier: + if "country" not in identifier or ( + # The specification requires a "phone" field, while Synapse used to require a "number" + # field. Accept both for backwards compatibility. + "phone" not in identifier + and "number" not in identifier + ): raise SynapseError(400, "Invalid phone-type identifier") - msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) + # Accept both "phone" and "number" as valid keys in m.id.phone + phone_number = identifier.get("phone", identifier["number"]) + + msisdn = phone_number_to_msisdn(identifier["country"], phone_number) return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} @@ -73,7 +81,8 @@ class LoginRestServlet(RestServlet): CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" - JWT_TYPE = "m.login.jwt" + JWT_TYPE = "org.matrix.login.jwt" + JWT_TYPE_DEPRECATED = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__() @@ -108,6 +117,7 @@ class LoginRestServlet(RestServlet): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) if self.cas_enabled: # we advertise CAS for backwards compat, though MSC1721 renamed it @@ -141,6 +151,7 @@ class LoginRestServlet(RestServlet): try: if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE + or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): result = await self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eec16f8ad8..970fdd5834 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py
@@ -17,8 +17,6 @@ """ import logging -from six import string_types - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -51,7 +49,9 @@ class PresenceStatusRestServlet(RestServlet): raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state(state, self.clock.time_msec()) + state = format_user_presence_state( + state, self.clock.time_msec(), include_user_id=False + ) return 200, state @@ -71,7 +71,7 @@ class PresenceStatusRestServlet(RestServlet): if "status_msg" in content: state["status_msg"] = content.pop("status_msg") - if not isinstance(state["status_msg"], string_types): + if not isinstance(state["status_msg"], str): raise SynapseError(400, "status_msg must be a string.") if content: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..165313b572 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@ # limitations under the License. """ This module contains REST servlets to do with profile: /profile/<paths> """ +from twisted.internet import defer from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet): super(ProfileDisplaynameRestServlet, self).__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() + self.http_client = hs.get_simple_http_client() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -63,11 +65,27 @@ class ProfileDisplaynameRestServlet(RestServlet): await self.profile_handler.set_displayname(user, requester, new_name, is_admin) + if self.hs.config.shadow_server: + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) + self.shadow_displayname(shadow_user.to_string(), content) + return 200, {} def on_OPTIONS(self, request, user_id): return 200, {} + @defer.inlineCallbacks + def shadow_displayname(self, user_id, body): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.put_json( + "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, + ) + class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) @@ -76,6 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet): super(ProfileAvatarURLRestServlet, self).__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() + self.http_client = hs.get_simple_http_client() self.auth = hs.get_auth() async def on_GET(self, request, user_id): @@ -114,11 +133,27 @@ class ProfileAvatarURLRestServlet(RestServlet): user, requester, new_avatar_url, is_admin ) + if self.hs.config.shadow_server: + shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs")) + self.shadow_avatar_url(shadow_user.to_string(), content) + return 200, {} def on_OPTIONS(self, request, user_id): return 200, {} + @defer.inlineCallbacks + def shadow_avatar_url(self, user_id, body): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.put_json( + "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s" + % (shadow_hs_url, user_id, as_token, user_id), + body, + ) + class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 105e0cf4d2..bbf98a5cfe 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -18,8 +18,7 @@ import logging import re from typing import List, Optional - -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse from canonicaljson import json @@ -726,7 +725,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): content["id_server"], requester, txn_id, - content.get("id_access_token"), + new_room=False, + id_access_token=content.get("id_access_token"), ) return 200, {} diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index b58a77826f..aeaf322985 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re -from six.moves import http_client +from http import HTTPStatus + +from twisted.internet import defer from synapse.api.constants import LoginType from synapse.api.errors import Codes, SynapseError, ThreepidValidationError @@ -29,6 +32,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import check_3pid_allowed @@ -91,7 +95,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if not check_3pid_allowed(self.hs, "email", email): raise SynapseError( 403, - "Your email domain is not authorized on this server", + "Your email is not authorized on this server", Codes.THREEPID_DENIED, ) @@ -221,6 +225,7 @@ class PasswordRestServlet(RestServlet): self.datastore = self.hs.get_datastore() self.password_policy_handler = hs.get_password_policy_handler() self._set_password_handler = hs.get_set_password_handler() + self.http_client = hs.get_simple_http_client() @interactive_auth_handler async def on_POST(self, request): @@ -252,13 +257,17 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = await self.auth.get_user_by_req(request) - params = await self.auth_handler.validate_user_via_ui_auth( - requester, - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + # blindly trust ASes without UI-authing them + if requester.app_service: + params = body + else: + params = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) user_id = requester.user.to_string() else: requester = None @@ -298,11 +307,29 @@ class PasswordRestServlet(RestServlet): user_id, new_password_hash, logout_devices, requester ) + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + self.shadow_password(params, shadow_user.to_string()) + return 200, {} def on_OPTIONS(self, _): return 200, {} + @defer.inlineCallbacks + def shadow_password(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") @@ -320,7 +347,7 @@ class DeactivateAccountRestServlet(RestServlet): erase = body.get("erase", False) if not isinstance(erase, bool): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'erase' must be a boolean, if given", Codes.BAD_JSON, ) @@ -397,13 +424,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not (await check_3pid_allowed(self.hs, "email", email)): raise SynapseError( 403, - "Your email domain is not authorized on this server", + "Your email is not authorized on this server", Codes.THREEPID_DENIED, ) + assert_valid_client_secret(body["client_secret"]) + existing_user_id = await self.store.get_user_id_by_threepid( "email", body["email"] ) @@ -467,13 +496,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( 403, "Account phone numbers are not authorized on this server", Codes.THREEPID_DENIED, ) + assert_valid_client_secret(body["client_secret"]) + existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: @@ -632,7 +663,8 @@ class ThreepidRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() - self.datastore = self.hs.get_datastore() + self.datastore = hs.get_datastore() + self.http_client = hs.get_simple_http_client() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) @@ -651,6 +683,29 @@ class ThreepidRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) + # skip validation if this is a shadow 3PID from an AS + if requester.app_service: + # XXX: ASes pass in a validated threepid directly to bypass the IS. + # This makes the API entirely change shape when we have an AS token; + # it really should be an entirely separate API - perhaps + # /account/3pid/replicate or something. + threepid = body.get("threepid") + + await self.auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) + + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + + return 200, {} + threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds") if threepid_creds is None: raise SynapseError( @@ -672,12 +727,36 @@ class ThreepidRestServlet(RestServlet): validation_session["address"], validation_session["validated_at"], ) + + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + threepid = { + "medium": validation_session["medium"], + "address": validation_session["address"], + "validated_at": validation_session["validated_at"], + } + self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) + @defer.inlineCallbacks + def shadow_3pid(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") @@ -688,6 +767,7 @@ class ThreepidAddRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.http_client = hs.get_simple_http_client() @interactive_auth_handler async def on_POST(self, request): @@ -723,12 +803,34 @@ class ThreepidAddRestServlet(RestServlet): validation_session["address"], validation_session["validated_at"], ) + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + threepid = { + "medium": validation_session["medium"], + "address": validation_session["address"], + "validated_at": validation_session["validated_at"], + } + self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) + @defer.inlineCallbacks + def shadow_3pid(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") @@ -798,6 +900,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() + self.http_client = hs.get_simple_http_client() async def on_POST(self, request): if not self.hs.config.enable_3pid_changes: @@ -822,6 +925,12 @@ class ThreepidDeleteRestServlet(RestServlet): logger.exception("Failed to remove threepid") raise SynapseError(500, "Failed to remove threepid") + if self.hs.config.shadow_server: + shadow_user = UserID( + requester.user.localpart, self.hs.config.shadow_server.get("hs") + ) + self.shadow_3pid_delete(body, shadow_user.to_string()) + if ret: id_server_unbind_result = "success" else: @@ -829,6 +938,77 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} + @defer.inlineCallbacks + def shadow_3pid_delete(self, body, user_id): + # TODO: retries + shadow_hs_url = self.hs.config.shadow_server.get("hs_url") + as_token = self.hs.config.shadow_server.get("as_token") + + yield self.http_client.post_json_get_json( + "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" + % (shadow_hs_url, as_token, user_id), + body, + ) + + +class ThreepidLookupRestServlet(RestServlet): + PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")] + + def __init__(self, hs): + super(ThreepidLookupRestServlet, self).__init__() + self.auth = hs.get_auth() + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_GET(self, request): + """Proxy a /_matrix/identity/api/v1/lookup request to an identity + server + """ + yield self.auth.get_user_by_req(request) + + # Verify query parameters + query_params = request.args + assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"]) + + # Retrieve needed information from query parameters + medium = parse_string(request, "medium") + address = parse_string(request, "address") + id_server = parse_string(request, "id_server") + + # Proxy the request to the identity server. lookup_3pid handles checking + # if the lookup is allowed so we don't need to do it here. + ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address) + + defer.returnValue((200, ret)) + + +class ThreepidBulkLookupRestServlet(RestServlet): + PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")] + + def __init__(self, hs): + super(ThreepidBulkLookupRestServlet, self).__init__() + self.auth = hs.get_auth() + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity + server + """ + yield self.auth.get_user_by_req(request) + + body = parse_json_object_from_request(request) + + assert_params_in_dict(body, ["threepids", "id_server"]) + + # Proxy the request to the identity server. lookup_3pid handles checking + # if the lookup is allowed so we don't need to do it here. + ret = yield self.identity_handler.proxy_bulk_lookup_3pid( + body["id_server"], body["threepids"] + ) + + defer.returnValue((200, ret)) + class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") @@ -857,4 +1037,6 @@ def register_servlets(hs, http_server): ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) + ThreepidLookupRestServlet(hs).register(http_server) + ThreepidBulkLookupRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..d31ec7c29d 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.types import UserID from ._base import client_patterns @@ -39,6 +40,7 @@ class AccountDataServlet(RestServlet): self.store = hs.get_datastore() self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None + self._profile_handler = hs.get_profile_handler() async def on_PUT(self, request, user_id, account_data_type): if self._is_worker: @@ -50,6 +52,11 @@ class AccountDataServlet(RestServlet): body = parse_json_object_from_request(request) + if account_data_type == "im.vector.hide_profile": + user = UserID.from_string(user_id) + hide_profile = body.get("hide_profile") + await self._profile_handler.set_active([user], not hide_profile, True) + max_id = await self.store.add_account_data_for_user( user_id, account_data_type, body ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c8d2de7b54..6255cd2e21 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -# Copyright 2015 - 2016 OpenMarket Ltd -# Copyright 2017 Vector Creations Ltd +# Copyright 2015-2016 OpenMarket Ltd +# Copyright 2017-2018 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,10 +17,9 @@ import hmac import logging +import re from typing import List, Union -from six import string_types - import synapse import synapse.api.auth import synapse.types @@ -122,10 +122,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): + if not (await check_3pid_allowed(self.hs, "email", body["email"])): raise SynapseError( 403, - "Your email domain is not authorized to register on this server", + "Your email is not authorized to register on this server", Codes.THREEPID_DENIED, ) @@ -194,7 +194,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): msisdn = phone_number_to_msisdn(country, phone_number) - if not check_3pid_allowed(self.hs, "msisdn", msisdn): + assert_valid_client_secret(body["client_secret"]) + + if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)): raise SynapseError( 403, "Phone numbers are not authorized to register on this server", @@ -350,15 +352,9 @@ class UsernameAvailabilityRestServlet(RestServlet): 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) - ip = self.hs.get_ip_from_request(request) - with self.ratelimiter.ratelimit(ip) as wait_deferred: - await wait_deferred - - username = parse_string(request, "username", required=True) - - await self.registration_handler.check_username(username) - - return 200, {"available": True} + # We are not interested in logging in via a username in this deployment. + # Simply allow anything here as it won't be used later. + return 200, {"available": True} class RegisterRestServlet(RestServlet): @@ -409,9 +405,10 @@ class RegisterRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. + desired_password_hash = None if "password" in body: password = body.pop("password") - if not isinstance(password, string_types) or len(password) > 512: + if not isinstance(password, str) or len(password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(password) @@ -420,15 +417,19 @@ class RegisterRestServlet(RestServlet): if "password_hash" in body: raise SynapseError(400, "Unexpected property: password_hash") body["password_hash"] = await self.auth_handler.hash(password) + desired_password_hash = body["password_hash"] + # We don't care about usernames for this deployment. In fact, the act + # of checking whether they exist already can leak metadata about + # which users are already registered. + # + # Usernames are already derived via the provided email. + # So, if they're not necessary, just ignore them. + # + # (we do still allow appservices to set them below) desired_username = None - if "username" in body: - if ( - not isinstance(body["username"], string_types) - or len(body["username"]) > 512 - ): - raise SynapseError(400, "Invalid username") - desired_username = body["username"] + + desired_display_name = body.get("display_name") appservice = None if self.auth.has_access_token(request): @@ -442,7 +443,7 @@ class RegisterRestServlet(RestServlet): # Set the desired user according to the AS API (which uses the # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. - desired_username = body.get("user", desired_username) + desired_username = body.get("user", body.get("username")) # XXX we should check that desired_username is valid. Currently # we give appservices carte blanche for any insanity in mxids, @@ -451,21 +452,16 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) - if isinstance(desired_username, string_types): + if isinstance(desired_username, str): result = await self._do_appservice_registration( - desired_username, access_token, body + desired_username, + desired_password_hash, + desired_display_name, + access_token, + body, ) return 200, result # we throw for non 200 responses - # for regular registration, downcase the provided username before - # attempting to register it. This should mean - # that people who try to register with upper-case in their usernames - # don't get a nasty surprise. (Note that we treat username - # case-insenstively in login, so they are free to carry on imagining - # that their username is CrAzYh4cKeR if that keeps them happy) - if desired_username is not None: - desired_username = desired_username.lower() - # == Normal User Registration == (everyone else) if not self.hs.config.enable_registration: raise SynapseError(403, "Registration has been disabled") @@ -491,13 +487,6 @@ class RegisterRestServlet(RestServlet): session_id, "registered_user_id", None ) - if desired_username is not None: - await self.registration_handler.check_username( - desired_username, - guest_access_token=guest_access_token, - assigned_user_id=registered_user_id, - ) - auth_result, params, session_id = await self.auth_handler.check_auth( self._registration_flows, request, @@ -518,7 +507,7 @@ class RegisterRestServlet(RestServlet): medium = auth_result[login_type]["medium"] address = auth_result[login_type]["address"] - if not check_3pid_allowed(self.hs, medium, address): + if not (await check_3pid_allowed(self.hs, medium, address)): raise SynapseError( 403, "Third party identifiers (email/phone numbers)" @@ -526,6 +515,80 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + existingUid = await self.store.get_user_id_by_threepid( + medium, address + ) + + if existingUid is not None: + raise SynapseError( + 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE + ) + + if self.hs.config.register_mxid_from_3pid: + # override the desired_username based on the 3PID if any. + # reset it first to avoid folks picking their own username. + desired_username = None + + # we should have an auth_result at this point if we're going to progress + # to register the user (i.e. we haven't picked up a registered_user_id + # from our session store), in which case get ready and gen the + # desired_username + if auth_result: + if ( + self.hs.config.register_mxid_from_3pid == "email" + and LoginType.EMAIL_IDENTITY in auth_result + ): + address = auth_result[LoginType.EMAIL_IDENTITY]["address"] + desired_username = synapse.types.strip_invalid_mxid_characters( + address.replace("@", "-").lower() + ) + + # find a unique mxid for the account, suffixing numbers + # if needed + while True: + try: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + # if we got this far we passed the check. + break + except SynapseError as e: + if e.errcode == Codes.USER_IN_USE: + m = re.match(r"^(.*?)(\d+)$", desired_username) + if m: + desired_username = m.group(1) + str( + int(m.group(2)) + 1 + ) + else: + desired_username += "1" + else: + # something else went wrong. + break + + if self.hs.config.register_just_use_email_for_display_name: + desired_display_name = address + else: + # Custom mapping between email address and display name + desired_display_name = _map_email_to_displayname(address) + elif ( + self.hs.config.register_mxid_from_3pid == "msisdn" + and LoginType.MSISDN in auth_result + ): + desired_username = auth_result[LoginType.MSISDN]["address"] + else: + raise SynapseError( + 400, "Cannot derive mxid from 3pid; no recognised 3pid" + ) + + if desired_username is not None: + await self.registration_handler.check_username( + desired_username, + guest_access_token=guest_access_token, + assigned_user_id=registered_user_id, + ) + if registered_user_id is not None: logger.info( "Already registered user ID %r for this session", registered_user_id @@ -536,7 +599,12 @@ class RegisterRestServlet(RestServlet): # NB: This may be from the auth handler and NOT from the POST assert_params_in_dict(params, ["password_hash"]) - desired_username = params.get("username", None) + if not self.hs.config.register_mxid_from_3pid: + desired_username = params.get("username", None) + else: + # we keep the original desired_username derived from the 3pid above + pass + guest_access_token = params.get("guest_access_token", None) new_password_hash = params.get("password_hash", None) @@ -573,6 +641,7 @@ class RegisterRestServlet(RestServlet): localpart=desired_username, password_hash=new_password_hash, guest_access_token=guest_access_token, + default_display_name=desired_display_name, threepid=threepid, address=client_addr, ) @@ -584,6 +653,14 @@ class RegisterRestServlet(RestServlet): ): await self.store.upsert_monthly_active_user(registered_user_id) + if self.hs.config.shadow_server: + await self.registration_handler.shadow_register( + localpart=desired_username, + display_name=desired_display_name, + auth_result=auth_result, + params=params, + ) + # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) await self.auth_handler.set_session_data( @@ -608,11 +685,30 @@ class RegisterRestServlet(RestServlet): def on_OPTIONS(self, _): return 200, {} - async def _do_appservice_registration(self, username, as_token, body): + async def _do_appservice_registration( + self, username, password_hash, display_name, as_token, body + ): + # FIXME: appservice_register() is horribly duplicated with register() + # and they should probably just be combined together with a config flag. user_id = await self.registration_handler.appservice_register( - username, as_token + username, as_token, password_hash, display_name ) - return await self._create_registration_details(user_id, body) + result = await self._create_registration_details(user_id, body) + + auth_result = body.get("auth_result") + if auth_result and LoginType.EMAIL_IDENTITY in auth_result: + threepid = auth_result[LoginType.EMAIL_IDENTITY] + await self.registration_handler.register_email_threepid( + user_id, threepid, result["access_token"] + ) + + if auth_result and LoginType.MSISDN in auth_result: + threepid = auth_result[LoginType.MSISDN] + await self.registration_handler.register_msisdn_threepid( + user_id, threepid, result["access_token"] + ) + + return result async def _create_registration_details(self, user_id, params): """Complete registration of newly-registered user @@ -663,6 +759,60 @@ class RegisterRestServlet(RestServlet): ) +def cap(name): + """Capitalise parts of a name containing different words, including those + separated by hyphens. + For example, 'John-Doe' + + Args: + name (str): The name to parse + """ + if not name: + return name + + # Split the name by whitespace then hyphens, capitalizing each part then + # joining it back together. + capatilized_name = " ".join( + "-".join(part.capitalize() for part in space_part.split("-")) + for space_part in name.split() + ) + return capatilized_name + + +def _map_email_to_displayname(address): + """Custom mapping from an email address to a user displayname + + Args: + address (str): The email address to process + Returns: + str: The new displayname + """ + # Split the part before and after the @ in the email. + # Replace all . with spaces in the first part + parts = address.replace(".", " ").split("@") + + # Figure out which org this email address belongs to + org_parts = parts[1].split(" ") + + # If this is a ...matrix.org email, mark them as an Admin + if org_parts[-2] == "matrix" and org_parts[-1] == "org": + org = "Tchap Admin" + + # Is this is a ...gouv.fr address, set the org to whatever is before + # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their + # org as "gouv" + elif org_parts[-2] == "gouv" and org_parts[-1] == "fr": + org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2] + + # Otherwise, mark their org as the email's second-level domain name + else: + org = org_parts[-2] + + desired_display_name = cap(parts[0]) + " [" + cap(org) + "]" + + return desired_display_name + + def _calculate_registration_flows( # technically `config` has to provide *all* of these interfaces, not just one config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig], diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index f067b5edac..e15927c4ea 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -14,9 +14,7 @@ # limitations under the License. import logging - -from six import string_types -from six.moves import http_client +from http import HTTPStatus from synapse.api.errors import Codes, SynapseError from synapse.http.servlet import ( @@ -47,15 +45,15 @@ class ReportEventRestServlet(RestServlet): body = parse_json_object_from_request(request) assert_params_in_dict(body, ("reason", "score")) - if not isinstance(body["reason"], string_types): + if not isinstance(body["reason"], str): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'reason' must be a string", Codes.BAD_JSON, ) if not isinstance(body["score"], int): raise SynapseError( - http_client.BAD_REQUEST, + HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", Codes.BAD_JSON, ) diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..6e8300d6a5 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@ # limitations under the License. import logging +from typing import Dict -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet, parse_json_object_from_request +from signedjson.sign import sign_json + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import ( + RestServlet, + assert_params_in_dict, + parse_json_object_from_request, +) +from synapse.types import UserID from ._base import client_patterns @@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet): self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() + self.http_client = hs.get_simple_http_client() async def on_POST(self, request): """Searches for users in directory @@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet): body = parse_json_object_from_request(request) + if self.hs.config.user_directory_defer_to_id_server: + signed_body = sign_json( + body, self.hs.hostname, self.hs.config.signing_key[0] + ) + url = "%s/_matrix/identity/api/v1/user_directory/search" % ( + self.hs.config.user_directory_defer_to_id_server, + ) + resp = await self.http_client.post_json_get_json(url, signed_body) + return 200, resp + limit = body.get("limit", 10) limit = min(limit, 50) @@ -76,5 +95,125 @@ class UserDirectorySearchRestServlet(RestServlet): return 200, results +class SingleUserInfoServlet(RestServlet): + """ + Deprecated and replaced by `/users/info` + + GET /user/{user_id}/info HTTP/1.1 + """ + + PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$") + + def __init__(self, hs): + super(SingleUserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + registry = hs.get_federation_registry() + + if not registry.query_handlers.get("user_info"): + registry.register_query_handler("user_info", self._on_federation_query) + + async def on_GET(self, request, user_id): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + # Attempt to make a federation request to the server that owns this user + args = {"user_id": user_id} + res = await self.transport_layer.make_query( + user.domain, "user_info", args, retry_on_dns_fail=True + ) + return 200, res + + user_id_to_info = await self.store.get_info_for_users([user_id]) + return 200, user_id_to_info[user_id] + + async def _on_federation_query(self, args): + """Called when a request for user information appears over federation + + Args: + args (dict): Dictionary of query arguments provided by the request + + Returns: + Deferred[dict]: Deactivation and expiration information for a given user + """ + user_id = args.get("user_id") + if not user_id: + raise SynapseError(400, "user_id not provided") + + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + raise SynapseError(400, "User is not hosted on this homeserver") + + user_ids_to_info_dict = await self.store.get_info_for_users([user_id]) + return user_ids_to_info_dict[user_id] + + +class UserInfoServlet(RestServlet): + """Bulk version of `/user/{user_id}/info` endpoint + + GET /users/info HTTP/1.1 + + Returns a dictionary of user_id to info dictionary. Supports remote users + """ + + PATTERNS = client_patterns("/users/info$", unstable=True, releases=()) + + def __init__(self, hs): + super(UserInfoServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.transport_layer = hs.get_federation_transport_client() + + async def on_POST(self, request): + # Ensure the user is authenticated + await self.auth.get_user_by_req(request) + + # Extract the user_ids from the request + body = parse_json_object_from_request(request) + assert_params_in_dict(body, required=["user_ids"]) + + user_ids = body["user_ids"] + if not isinstance(user_ids, list): + raise SynapseError( + 400, + "'user_ids' must be a list of user ID strings", + errcode=Codes.INVALID_PARAM, + ) + + # Separate local and remote users + local_user_ids = set() + remote_server_to_user_ids = {} # type: Dict[str, set] + for user_id in user_ids: + user = UserID.from_string(user_id) + + if self.hs.is_mine(user): + local_user_ids.add(user_id) + else: + remote_server_to_user_ids.setdefault(user.domain, set()) + remote_server_to_user_ids[user.domain].add(user_id) + + # Retrieve info of all local users + user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids) + + # Request info of each remote user from their remote homeserver + for server_name, user_id_set in remote_server_to_user_ids.items(): + # Make a request to the given server about their own users + res = await self.transport_layer.get_info_of_users( + server_name, list(user_id_set) + ) + + for user_id, info in res: + user_id_to_info_dict[user_id] = info + + return 200, user_id_to_info_dict + + def register_servlets(hs, http_server): UserDirectorySearchRestServlet(hs).register(http_server) + SingleUserInfoServlet(hs).register(http_server) + UserInfoServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..b1999d051b 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -57,9 +57,12 @@ class VersionsRestServlet(RestServlet): # MSC2326. "org.matrix.label_based_filtering": True, # Implements support for cross signing as described in MSC1756 - "org.matrix.e2e_cross_signing": True, + # "org.matrix.e2e_cross_signing": True, # Implements additional endpoints as described in MSC2432 "org.matrix.msc2432": True, + # Tchap does not currently assume this rule for r0.5.0 + # XXX: Remove this when it does + "m.lazy_load_members": True, }, }, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4a20282d1b..0a890c98cb 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py
@@ -16,10 +16,9 @@ import hmac import logging from hashlib import sha256 +from http import HTTPStatus from os import path -from six.moves import http_client - import jinja2 from jinja2 import TemplateNotFound @@ -219,4 +218,4 @@ class ConsentResource(DirectServeResource): ) if not compare_digest(want_mac, userhmac): - raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect") + raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 3689777266..595849f9d5 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py
@@ -16,8 +16,7 @@ import logging import os - -from six.moves import urllib +import urllib from twisted.internet import defer from twisted.protocols.basic import FileSender diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index fd10d42f2f..45628c07b4 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -20,8 +20,6 @@ import os import shutil from typing import Dict, Tuple -from six import iteritems - import twisted.internet.error import twisted.web.http from twisted.web.resource import Resource @@ -340,7 +338,7 @@ class MediaRepository(object): with self.media_storage.store_into_file(file_info) as (f, fname, finish): request_path = "/".join( - ("/_matrix/media/v1/download", server_name, media_id) + ("/_matrix/media/r0/download", server_name, media_id) ) try: length, headers = await self.client.get_file( @@ -606,7 +604,7 @@ class MediaRepository(object): thumbnails[(t_width, t_height, r_type)] = r_method # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in iteritems(thumbnails): + for (t_width, t_height, t_type), t_method in thumbnails.items(): # Generate the thumbnail if t_method == "crop": t_byte_source = await defer_to_thread( @@ -705,7 +703,7 @@ class MediaRepositoryResource(Resource): Uploads are POSTed to a resource which returns a token which is used to GET the download:: - => POST /_matrix/media/v1/upload HTTP/1.1 + => POST /_matrix/media/r0/upload HTTP/1.1 Content-Type: <media-type> Content-Length: <content-length> @@ -716,7 +714,7 @@ class MediaRepositoryResource(Resource): { "content_uri": "mxc://<server-name>/<media-id>" } - => GET /_matrix/media/v1/download/<server-name>/<media-id> HTTP/1.1 + => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1 <= HTTP/1.1 200 OK Content-Type: <media-type> @@ -727,7 +725,7 @@ class MediaRepositoryResource(Resource): Clients can get thumbnails by supplying a desired width and height and thumbnailing method:: - => GET /_matrix/media/v1/thumbnail/<server_name> + => GET /_matrix/media/r0/thumbnail/<server_name> /<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1 <= HTTP/1.1 200 OK diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 683a79c966..79cb0dddbe 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py
@@ -17,9 +17,6 @@ import contextlib import logging import os import shutil -import sys - -import six from twisted.internet import defer from twisted.protocols.basic import FileSender @@ -117,12 +114,11 @@ class MediaStorage(object): with open(fname, "wb") as f: yield f, fname, finish except Exception: - t, v, tb = sys.exc_info() try: os.remove(fname) except Exception: pass - six.reraise(t, v, tb) + raise if not finished_called: raise Exception("Finished callback not called") diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f206605727..b4645cd608 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -24,10 +24,7 @@ import shutil import sys import traceback from typing import Dict, Optional - -import six -from six import string_types -from six.moves import urllib_parse as urlparse +from urllib import parse as urlparse from canonicaljson import json @@ -85,6 +82,15 @@ class PreviewUrlResource(DirectServeResource): self.primary_base_path = media_repo.primary_base_path self.media_storage = media_storage + # We run the background jobs if we're the instance specified (or no + # instance is specified, where we assume there is only one instance + # serving media). + instance_running_jobs = hs.config.media.media_instance_running_background_jobs + self._worker_run_media_background_jobs = ( + instance_running_jobs is None + or instance_running_jobs == hs.get_instance_name() + ) + self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist self.url_preview_accept_language = hs.config.url_preview_accept_language @@ -97,9 +103,10 @@ class PreviewUrlResource(DirectServeResource): expiry_ms=60 * 60 * 1000, ) - self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000 - ) + if self._worker_run_media_background_jobs: + self._cleaner_loop = self.clock.looping_call( + self._start_expire_url_cache_data, 10 * 1000 + ) def render_OPTIONS(self, request): request.setHeader(b"Allow", b"OPTIONS, GET") @@ -188,7 +195,7 @@ class PreviewUrlResource(DirectServeResource): # It may be stored as text in the database, not as bytes (such as # PostgreSQL). If so, encode it back before handing it on. og = cache_result["og"] - if isinstance(og, six.text_type): + if isinstance(og, str): og = og.encode("utf8") return og @@ -400,6 +407,8 @@ class PreviewUrlResource(DirectServeResource): """ # TODO: Delete from backup media store + assert self._worker_run_media_background_jobs + now = self.clock.time_msec() logger.debug("Running url preview cache expiry") @@ -631,7 +640,7 @@ def _iterate_over_text(tree, *tags_to_ignore): if el is None: return - if isinstance(el, string_types): + if isinstance(el, str): yield el elif el.tag not in tags_to_ignore: # el.text is the text before the first child, so we can immediately