diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..2e81eeff65 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -118,6 +118,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index b8c95d045a..7c292ef3f9 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -31,7 +31,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
historical_admin_path_patterns,
)
-from synapse.storage.data_stores.main.room import RoomSortOrder
+from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
logger = logging.getLogger(__name__)
@@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
+ purge = content.get("purge", True)
+ if not isinstance(purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet):
)
# Purge room
- await self.pagination_handler.purge_room(room_id)
+ if purge:
+ await self.pagination_handler.purge_room(room_id)
return (200, ret)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 5934b1fe8b..b210015173 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
try:
- service = await self.auth.get_appservice_by_req(request)
+ service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 970fdd5834..ceaa28c212 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -49,9 +49,7 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(
- state, self.clock.time_msec(), include_user_id=False
- )
+ state = format_user_presence_state(state, self.clock.time_msec())
return 200, state
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..165313b572 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -63,11 +65,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -76,6 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -114,11 +133,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 26d5a51cb2..ed82144475 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -444,7 +444,7 @@ class RoomMemberListRestServlet(RestServlet):
async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
@@ -724,7 +724,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3767a809a4..c8e793ade2 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +15,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+from twisted.internet import defer
from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+ Codes,
+ InteractiveAuthIncompleteError,
+ SynapseError,
+ ThreepidValidationError,
+)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
@@ -28,6 +41,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
@@ -98,10 +112,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -232,6 +250,7 @@ class PasswordRestServlet(RestServlet):
self.datastore = self.hs.get_datastore()
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -239,18 +258,12 @@ class PasswordRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
- if "new_password" in body:
- new_password = body.pop("new_password")
+ new_password = body.pop("new_password", None)
+ if new_password is not None:
if not isinstance(new_password, str) or len(new_password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "new_password_hash" in body:
- raise SynapseError(400, "Unexpected property: new_password_hash")
- body["new_password_hash"] = await self.auth_handler.hash(new_password)
-
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -263,23 +276,56 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ try:
+ (
+ params,
+ session_id,
+ ) = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
user_id = requester.user.to_string()
else:
requester = None
- result, params, _ = await self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY]],
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ try:
+ result, params, session_id = await self.auth_handler.check_ui_auth(
+ [[LoginType.EMAIL_IDENTITY]],
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
@@ -304,19 +350,46 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password_hash"])
- new_password_hash = params["new_password_hash"]
+ # If we have a password in this request, prefer it. Otherwise, there
+ # must be a password hash from an earlier request.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ else:
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
+
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
- user_id, new_password_hash, logout_devices, requester
+ user_id, password_hash, logout_devices, requester
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_password(params, shadow_user.to_string())
+
return 200, {}
def on_OPTIONS(self, _):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -420,13 +493,17 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
@@ -488,13 +565,17 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -579,15 +660,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
# Otherwise show the success template
html = self.config.email_add_threepid_template_success_html_content
@@ -653,7 +729,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -672,6 +749,29 @@ class ThreepidRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ # skip validation if this is a shadow 3PID from an AS
+ if requester.app_service:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
+
+ await self.auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
if threepid_creds is None:
raise SynapseError(
@@ -693,12 +793,36 @@ class ThreepidRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
@@ -709,6 +833,7 @@ class ThreepidAddRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
async def on_POST(self, request):
@@ -744,12 +869,34 @@ class ThreepidAddRestServlet(RestServlet):
validation_session["address"],
validation_session["validated_at"],
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
return 200, {}
raise SynapseError(
400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
@@ -819,6 +966,7 @@ class ThreepidDeleteRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
if not self.hs.config.enable_3pid_changes:
@@ -843,6 +991,12 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
@@ -850,6 +1004,116 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+ @defer.inlineCallbacks
+ def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
+
+ defer.returnValue((200, ret))
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ defer.returnValue((200, ret))
+
+
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+ """
+ Raises a SynapseError if a given next_link value is invalid
+
+ next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
+ option is either empty or contains a domain that matches the one in the given next_link
+
+ Args:
+ hs: The homeserver object
+ next_link: The next_link value given by the client
+
+ Raises:
+ SynapseError: If the next_link is invalid
+ """
+ valid = True
+
+ # Parse the contents of the URL
+ next_link_parsed = urlparse(next_link)
+
+ # Scheme must not point to the local drive
+ if next_link_parsed.scheme == "file":
+ valid = False
+
+ # If the domain whitelist is set, the domain must be in it
+ if (
+ valid
+ and hs.config.next_link_domain_whitelist is not None
+ and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+ ):
+ valid = False
+
+ if not valid:
+ raise SynapseError(
+ 400,
+ "'next_link' domain not included in whitelist, or not http(s)",
+ errcode=Codes.INVALID_PARAM,
+ )
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
@@ -878,4 +1142,6 @@ def register_servlets(hs, http_server):
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..d31ec7c29d 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -17,6 +17,7 @@ import logging
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -39,6 +40,7 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._profile_handler = hs.get_profile_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
@@ -50,6 +52,11 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active([user], not hide_profile, True)
+
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 370742ce59..662833ff49 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
import hmac
import logging
+import re
from typing import List, Union
import synapse
@@ -24,6 +26,7 @@ import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
+ InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@@ -127,10 +130,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
- if not check_3pid_allowed(self.hs, "email", email):
+ if not (await check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
@@ -199,7 +202,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(country, phone_number)
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -355,15 +360,9 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = self.hs.get_ip_from_request(request)
- with self.ratelimiter.ratelimit(ip) as wait_deferred:
- await wait_deferred
-
- username = parse_string(request, "username", required=True)
-
- await self.registration_handler.check_username(username)
-
- return 200, {"available": True}
+ # We are not interested in logging in via a username in this deployment.
+ # Simply allow anything here as it won't be used later.
+ return 200, {"available": True}
class RegisterRestServlet(RestServlet):
@@ -387,6 +386,7 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self._registration_enabled = self.hs.config.enable_registration
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -412,29 +412,26 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # we do basic sanity checks here because the auth layer will store these
- # in sessions. Pull out the username/password provided to us.
- if "password" in body:
- password = body.pop("password")
- if not isinstance(password, str) or len(password) > 512:
- raise SynapseError(400, "Invalid password")
- self.password_policy_handler.validate_password(password)
-
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "password_hash" in body:
- raise SynapseError(400, "Unexpected property: password_hash")
- body["password_hash"] = await self.auth_handler.hash(password)
-
+ # We don't care about usernames for this deployment. In fact, the act
+ # of checking whether they exist already can leak metadata about
+ # which users are already registered.
+ #
+ # Usernames are already derived via the provided email.
+ # So, if they're not necessary, just ignore them.
+ #
+ # (we do still allow appservices to set them below)
desired_username = None
- if "username" in body:
- if not isinstance(body["username"], str) or len(body["username"]) > 512:
- raise SynapseError(400, "Invalid username")
- desired_username = body["username"]
+
+ desired_display_name = body.get("display_name")
appservice = None
if self.auth.has_access_token(request):
- appservice = await self.auth.get_appservice_by_req(request)
+ appservice = self.auth.get_appservice_by_req(request)
+
+ # We need to retrieve the password early in order to pass it to
+ # application service registration
+ # This is specific to shadow server registration of users via an AS
+ password = body.pop("password", None)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -444,7 +441,7 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
- desired_username = body.get("user", desired_username)
+ desired_username = body.get("user", body.get("username"))
# XXX we should check that desired_username is valid. Currently
# we give appservices carte blanche for any insanity in mxids,
@@ -455,26 +452,28 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, str):
result = await self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username, password, desired_display_name, access_token, body
)
return 200, result # we throw for non 200 responses
- # for regular registration, downcase the provided username before
- # attempting to register it. This should mean
- # that people who try to register with upper-case in their usernames
- # don't get a nasty surprise. (Note that we treat username
- # case-insenstively in login, so they are free to carry on imagining
- # that their username is CrAzYh4cKeR if that keeps them happy)
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# == Normal User Registration == (everyone else)
- if not self.hs.config.enable_registration:
+ if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
+ # Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
- if "initial_device_display_name" in body and "password_hash" not in body:
+ # Pull out the provided password and do basic sanity checks early.
+ #
+ # Note that we remove the password from the body since the auth layer
+ # will store the body in the session and we don't want a plaintext
+ # password store there.
+ if password is not None:
+ if not isinstance(password, str) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(password)
+
+ if "initial_device_display_name" in body and password is None:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
@@ -484,6 +483,7 @@ class RegisterRestServlet(RestServlet):
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
+ password_hash = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
@@ -492,35 +492,50 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
-
- if desired_username is not None:
- await self.registration_handler.check_username(
- desired_username,
- guest_access_token=guest_access_token,
- assigned_user_id=registered_user_id,
+ # Extract the previously-hashed password from the session.
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
)
- auth_result, params, session_id = await self.auth_handler.check_auth(
- self._registration_flows,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "register a new account",
- )
+ # Check if the user-interactive authentication flows are complete, if
+ # not this will raise a user-interactive auth error.
+ try:
+ auth_result, params, session_id = await self.auth_handler.check_ui_auth(
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "register a new account",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth.
+ #
+ # Hash the password and store it with the session since the client
+ # is not required to provide the password again.
+ #
+ # If a password hash was previously stored we will not attempt to
+ # re-hash and store it for efficiency. This assumes the password
+ # does not change throughout the authentication flow, but this
+ # should be fine since the data is meant to be consistent.
+ if not password_hash and password:
+ password_hash = await self.auth_handler.hash(password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
-
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -528,6 +543,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = await self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and gen the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = _map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ await self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -535,12 +624,20 @@ class RegisterRestServlet(RestServlet):
# don't re-register the threepids
registered = False
else:
- # NB: This may be from the auth handler and NOT from the POST
- assert_params_in_dict(params, ["password_hash"])
+ # If we have a password in this request, prefer it. Otherwise, there
+ # might be a password hash from an earlier request.
+ if password:
+ password_hash = await self.auth_handler.hash(password)
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
+
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+ # we keep the original desired_username derived from the 3pid above
+ pass
- desired_username = params.get("username", None)
guest_access_token = params.get("guest_access_token", None)
- new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -582,8 +679,9 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password_hash=new_password_hash,
+ password_hash=password_hash,
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
)
@@ -595,8 +693,16 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
- # remember that we've now registered that user account, and with
- # what user ID (since the user may not have specified)
+ if self.hs.config.shadow_server:
+ await self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
+ # Remember that the user account has been registered (and the user
+ # ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
@@ -619,11 +725,38 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- async def _do_appservice_registration(self, username, as_token, body):
+ async def _do_appservice_registration(
+ self, username, password, display_name, as_token, body
+ ):
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
+
+ if password:
+ # Hash the password
+ #
+ # In mainline hashing of the password was moved further on in the registration
+ # flow, but we need it here for the AS use-case of shadow servers
+ password = await self.auth_handler.hash(password)
+
user_id = await self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password, display_name
)
- return await self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ await self.registration_handler.register_email_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ await self.registration_handler.register_msisdn_threepid(
+ user_id, threepid, result["access_token"]
+ )
+
+ return result
async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
@@ -635,7 +768,7 @@ class RegisterRestServlet(RestServlet):
(object) params: registration parameters, from which we pull
device_id, initial_device_name and inhibit_login
Returns:
- defer.Deferred: (object) dictionary for response from /register
+ (object) dictionary for response from /register
"""
result = {"user_id": user_id, "home_server": self.hs.hostname}
if not params.get("inhibit_login", False):
@@ -674,6 +807,60 @@ class RegisterRestServlet(RestServlet):
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capatilized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capatilized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Split the part before and after the @ in the email.
+ # Replace all . with spaces in the first part
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # Is this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..6e8300d6a5 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -14,9 +14,17 @@
# limitations under the License.
import logging
+from typing import Dict
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from signedjson.sign import sign_json
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import UserID
from ._base import client_patterns
@@ -35,6 +43,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
async def on_POST(self, request):
"""Searches for users in directory
@@ -61,6 +70,16 @@ class UserDirectorySearchRestServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ return 200, resp
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,5 +95,125 @@ class UserDirectorySearchRestServlet(RestServlet):
return 200, results
+class SingleUserInfoServlet(RestServlet):
+ """
+ Deprecated and replaced by `/users/info`
+
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(SingleUserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ async def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = await self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ return 200, res
+
+ user_id_to_info = await self.store.get_info_for_users([user_id])
+ return 200, user_id_to_info[user_id]
+
+ async def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ user_ids_to_info_dict = await self.store.get_info_for_users([user_id])
+ return user_ids_to_info_dict[user_id]
+
+
+class UserInfoServlet(RestServlet):
+ """Bulk version of `/user/{user_id}/info` endpoint
+
+ GET /users/info HTTP/1.1
+
+ Returns a dictionary of user_id to info dictionary. Supports remote users
+ """
+
+ PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.transport_layer = hs.get_federation_transport_client()
+
+ async def on_POST(self, request):
+ # Ensure the user is authenticated
+ await self.auth.get_user_by_req(request)
+
+ # Extract the user_ids from the request
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, required=["user_ids"])
+
+ user_ids = body["user_ids"]
+ if not isinstance(user_ids, list):
+ raise SynapseError(
+ 400,
+ "'user_ids' must be a list of user ID strings",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # Separate local and remote users
+ local_user_ids = set()
+ remote_server_to_user_ids = {} # type: Dict[str, set]
+ for user_id in user_ids:
+ user = UserID.from_string(user_id)
+
+ if self.hs.is_mine(user):
+ local_user_ids.add(user_id)
+ else:
+ remote_server_to_user_ids.setdefault(user.domain, set())
+ remote_server_to_user_ids[user.domain].add(user_id)
+
+ # Retrieve info of all local users
+ user_id_to_info_dict = await self.store.get_info_for_users(local_user_ids)
+
+ # Request info of each remote user from their remote homeserver
+ for server_name, user_id_set in remote_server_to_user_ids.items():
+ # Make a request to the given server about their own users
+ res = await self.transport_layer.get_info_of_users(
+ server_name, list(user_id_set)
+ )
+
+ for user_id, info in res:
+ user_id_to_info_dict[user_id] = info
+
+ return 200, user_id_to_info_dict
+
+
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ SingleUserInfoServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..b1999d051b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -57,9 +57,12 @@ class VersionsRestServlet(RestServlet):
# MSC2326.
"org.matrix.label_based_filtering": True,
# Implements support for cross signing as described in MSC1756
- "org.matrix.e2e_cross_signing": True,
+ # "org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Tchap does not currently assume this rule for r0.5.0
+ # XXX: Remove this when it does
+ "m.lazy_load_members": True,
},
},
)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4386eb4e72..b3e4d5612e 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -22,8 +22,6 @@ from os import path
import jinja2
from jinja2 import TemplateNotFound
-from twisted.internet import defer
-
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
@@ -135,7 +133,7 @@ class ConsentResource(DirectServeHtmlResource):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
- u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id)
+ u = await self.store.get_user_by_id(qualified_user_id)
if u is None:
raise NotFoundError("Unknown user")
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
new file mode 100644
index 0000000000..0170950bf3
--- /dev/null
+++ b/synapse/rest/health.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+
+class HealthResource(Resource):
+ """A resource that does nothing except return a 200 with a body of `OK`,
+ which can be used as a health check.
+
+ Note: `SynapseRequest._should_log_request` ensures that requests to
+ `/health` do not get logged at INFO.
+ """
+
+ isLeaf = 1
+
+ def render_GET(self, request):
+ request.setHeader(b"Content-Type", b"text/plain")
+ return b"OK"
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 9a847130c0..20ddb9550b 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,9 @@
import logging
import os
import urllib
+from typing import Awaitable
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -240,14 +242,14 @@ class Responder(object):
held can be cleaned up.
"""
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
- consumer (IConsumer)
+ consumer: The consumer to stream into.
Returns:
- Deferred: Resolves once the response has finished being written
+ Resolves once the response has finished being written
"""
pass
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 45628c07b4..6fb4039e98 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,10 +18,11 @@ import errno
import logging
import os
import shutil
-from typing import Dict, Tuple
+from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
+ Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
- self, media_type, upload_name, content, content_length, auth_user
- ):
+ self,
+ media_type: str,
+ upload_name: str,
+ content: IO,
+ content_length: int,
+ auth_user: str,
+ ) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
- media_type(str): The content type of the file
- upload_name(str): The name of the file
+ media_type: The content type of the file
+ upload_name: The name of the file
content: A file like object that is the content to store
- content_length(int): The length of the content
- auth_user(str): The user_id of the uploader
+ content_length: The length of the content
+ auth_user: The user_id of the uploader
Returns:
- Deferred[str]: The mxc url of the stored content
+ The mxc url of the stored content
"""
media_id = random_string(24)
@@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
- async def get_local_media(self, request, media_id, name):
+ async def get_local_media(
+ self, request: Request, media_id: str, name: Optional[str]
+ ) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
- request(twisted.web.http.Request)
- media_id (str): The media ID of the content. (This is the same as
+ request: The incoming request.
+ media_id: The media ID of the content. (This is the same as
the file_id for local content.)
- name (str|None): Optional name that, if specified, will be used as
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
- async def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(
+ self, request: Request, server_name: str, media_id: str, name: Optional[str]
+ ) -> None:
"""Respond to requests for remote media.
Args:
- request(twisted.web.http.Request)
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
- name (str|None): Optional name that, if specified, will be used as
+ request: The incoming request.
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@@ -245,17 +253,16 @@ class MediaRepository(object):
else:
respond_404(request)
- async def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
Returns:
- Deferred[dict]: The media_info of the file
+ The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info
- async def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(
+ self, server_name: str, media_id: str
+ ) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -288,7 +297,7 @@ class MediaRepository(object):
remote server).
Returns:
- Deferred[(Responder, media_info)]
+ A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(
+ self, server_name: str, media_id: str, file_id: str
+ ) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
- server_name (str): Originating server
- media_id (str): The media ID of the content (as defined by the
+ server_name: Originating server
+ media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
- file_id (str): Local file ID
+ file_id: Local file ID
Returns:
- Deferred[MediaInfo]
+ The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
@@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
- self, server_name, media_id, file_id, media_type, url_cache=False
- ):
+ self,
+ server_name: Optional[str],
+ media_id: str,
+ file_id: str,
+ media_type: str,
+ url_cache: bool = False,
+ ) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
- server_name (str|None): The server name if remote media, else None if local
- media_id (str): The media ID of the content. (This is the same as
+ server_name: The server name if remote media, else None if local
+ media_id: The media ID of the content. (This is the same as
the file_id for local content)
- file_id (str): Local file ID
- media_type (str): The content type of the file
- url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+ file_id: Local file ID
+ media_type: The content type of the file
+ url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
- Deferred[dict]: Dict with "width" and "height" keys of original image
+ Dict with "width" and "height" keys of original image or None if the
+ media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
- return
+ return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
@@ -584,7 +601,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 66bc1c3360..ab1fa705bf 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -12,13 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import contextlib
-import inspect
import logging
import os
import shutil
-from typing import Optional
+from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.protocols.basic import FileSender
@@ -26,6 +24,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+ from .storage_provider import StorageProviderWrapper
logger = logging.getLogger(__name__)
@@ -34,20 +38,25 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
- hs (synapse.server.Homeserver)
- local_media_directory (str): Base path where we store media on disk
- filepaths (MediaFilePaths)
- storage_providers ([StorageProvider]): List of StorageProvider that are
- used to fetch and store files.
+ hs
+ local_media_directory: Base path where we store media on disk
+ filepaths
+ storage_providers: List of StorageProvider that are used to fetch and store files.
"""
- def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ local_media_directory: str,
+ filepaths: MediaFilePaths,
+ storage_providers: Sequence["StorageProviderWrapper"],
+ ):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
- async def store_file(self, source, file_info: FileInfo) -> str:
+ async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
@@ -69,7 +78,7 @@ class MediaStorage(object):
return fname
@contextlib.contextmanager
- def store_into_file(self, file_info):
+ def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -85,7 +94,7 @@ class MediaStorage(object):
error.
Args:
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Example:
@@ -105,11 +114,7 @@ class MediaStorage(object):
async def finish():
for provider in self.storage_providers:
- # store_file is supposed to return an Awaitable, but guard
- # against improper implementations.
- result = provider.store_file(path, file_info)
- if inspect.isawaitable(result):
- await result
+ await provider.store_file(path, file_info)
finished_called[0] = True
@@ -143,11 +148,7 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
- res = provider.fetch(path, file_info)
- # Fetch is supposed to return an Awaitable, but guard against
- # improper implementations.
- if inspect.isawaitable(res):
- res = await res
+ res = await provider.fetch(path, file_info) # type: Any
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
@@ -174,11 +175,7 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = provider.fetch(path, file_info)
- # Fetch is supposed to return an Awaitable, but guard against
- # improper implementations.
- if inspect.isawaitable(res):
- res = await res
+ res = await provider.fetch(path, file_info) # type: Any
if res:
with res:
consumer = BackgroundFileConsumer(
@@ -190,17 +187,11 @@ class MediaStorage(object):
raise Exception("file could not be found")
- def _file_info_to_path(self, file_info):
+ def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
-
- Args:
- file_info (FileInfo)
-
- Returns:
- str
"""
if file_info.url_cache:
if file_info.thumbnail:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 13d1a6d2ed..cd8c246594 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,9 +27,7 @@ from typing import Dict, Optional
from urllib import parse as urlparse
import attr
-from canonicaljson import json
-from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
@@ -43,6 +41,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.stringutils import random_string
@@ -228,19 +227,19 @@ class PreviewUrlResource(DirectServeJsonResource):
else:
logger.info("Returning cached response")
- og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
+ og = await make_deferred_yieldable(observable.observe())
respond_with_json_bytes(request, 200, og, send_cors=True)
- async def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
- url (str):
- user (str):
- ts (int):
+ url: The URL to preview.
+ user: The user requesting the preview.
+ ts: The timestamp requested for the preview.
Returns:
- Deferred[bytes]: json-encoded og data
+ json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
@@ -355,7 +354,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Calculated OG for %s as %s", url, og)
- jsonog = json.dumps(og)
+ jsonog = json_encoder.encode(og)
# store OG in history-aware DB cache
await self.store.store_url_cache(
@@ -586,7 +585,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Running url preview cache expiry")
- if not (await self.store.db.updates.has_completed_background_updates()):
+ if not (await self.store.db_pool.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..18c9ed48d6 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,65 +13,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import os
import shutil
-
-from twisted.internet import defer
+from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
-class StorageProvider(object):
+class StorageProvider:
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
-
- Returns:
- Deferred
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
"""
- pass
- def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
Returns:
- Deferred(Responder): Returns a Responder if the provider has the file,
- otherwise returns None.
+ Returns a Responder if the provider has the file, otherwise returns None.
"""
- pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
- backend (StorageProvider)
- store_local (bool): Whether to store new local files or not.
- store_synchronous (bool): Whether to wait for file to be successfully
+ backend: The storage provider to wrap.
+ store_local: Whether to store new local files or not.
+ store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background.
- store_remote (bool): Whether remote media should be uploaded
+ store_remote: Whether remote media should be uploaded
"""
- def __init__(self, backend, store_local, store_synchronous, store_remote):
+ def __init__(
+ self,
+ backend: StorageProvider,
+ store_local: bool,
+ store_synchronous: bool,
+ store_remote: bool,
+ ):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
@@ -80,28 +81,38 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
- return defer.succeed(None)
+ return None
if file_info.server_name and not self.store_remote:
- return defer.succeed(None)
+ return None
if self.store_synchronous:
- return self.backend.store_file(path, file_info)
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
else:
# TODO: Handle errors.
- def store():
+ async def store():
try:
- return self.backend.store_file(path, file_info)
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
except Exception:
logger.exception("Error storing file")
run_in_background(store)
- return defer.succeed(None)
+ return None
- def fetch(self, path, file_info):
- return self.backend.fetch(path, file_info)
+ async def fetch(self, path, file_info):
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = self.backend.fetch(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
class FileStorageProviderBackend(StorageProvider):
@@ -120,7 +131,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -130,11 +141,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return defer_to_thread(
+ return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- def fetch(self, path, file_info):
+ async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
|