diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 4a1fc2ec2b..46e458e95b 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
keys,
notifications,
openid,
+ password_policy,
read_marker,
receipts,
register,
@@ -118,6 +119,7 @@ class ClientRestResource(JsonResource):
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 81b6bd8816..9eda592de9 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -14,64 +14,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
-import hmac
import logging
import platform
import re
-from six import text_type
-from six.moves import http_client
-
-from twisted.internet import defer
-
import synapse
-from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import (
- RestServlet,
- assert_params_in_dict,
- parse_integer,
- parse_json_object_from_request,
- parse_string,
-)
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
assert_requester_is_admin,
- assert_user_is_admin,
historical_admin_path_patterns,
)
+from synapse.rest.admin.devices import (
+ DeleteDevicesRestServlet,
+ DeviceRestServlet,
+ DevicesRestServlet,
+)
+from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
+from synapse.rest.admin.rooms import (
+ JoinRoomAliasServlet,
+ ListRoomRestServlet,
+ RoomRestServlet,
+ ShutdownRoomRestServlet,
+)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
-from synapse.rest.admin.users import UserAdminServlet
-from synapse.types import UserID, create_requester
+from synapse.rest.admin.users import (
+ AccountValidityRenewServlet,
+ DeactivateAccountRestServlet,
+ ResetPasswordRestServlet,
+ SearchUsersRestServlet,
+ UserAdminServlet,
+ UserRegisterServlet,
+ UserRestServletV2,
+ UsersRestServlet,
+ UsersRestServletV2,
+ WhoisRestServlet,
+)
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
-class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- yield assert_requester_is_admin(self.auth, request)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- ret = yield self.handlers.admin_handler.get_users()
-
- return 200, ret
-
-
class VersionServlet(RestServlet):
PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
@@ -85,161 +71,6 @@ class VersionServlet(RestServlet):
return 200, self.res
-class UserRegisterServlet(RestServlet):
- """
- Attributes:
- NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
- nonces (dict[str, int]): The nonces that we will accept. A dict of
- nonce to the time it was generated, in int seconds.
- """
-
- PATTERNS = historical_admin_path_patterns("/register")
- NONCE_TIMEOUT = 60
-
- def __init__(self, hs):
- self.handlers = hs.get_handlers()
- self.reactor = hs.get_reactor()
- self.nonces = {}
- self.hs = hs
-
- def _clear_old_nonces(self):
- """
- Clear out old nonces that are older than NONCE_TIMEOUT.
- """
- now = int(self.reactor.seconds())
-
- for k, v in list(self.nonces.items()):
- if now - v > self.NONCE_TIMEOUT:
- del self.nonces[k]
-
- def on_GET(self, request):
- """
- Generate a new nonce.
- """
- self._clear_old_nonces()
-
- nonce = self.hs.get_secrets().token_hex(64)
- self.nonces[nonce] = int(self.reactor.seconds())
- return 200, {"nonce": nonce}
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- self._clear_old_nonces()
-
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
-
- body = parse_json_object_from_request(request)
-
- if "nonce" not in body:
- raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
-
- nonce = body["nonce"]
-
- if nonce not in self.nonces:
- raise SynapseError(400, "unrecognised nonce")
-
- # Delete the nonce, so it can't be reused, even if it's invalid
- del self.nonces[nonce]
-
- if "username" not in body:
- raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON
- )
- else:
- if (
- not isinstance(body["username"], text_type)
- or len(body["username"]) > 512
- ):
- raise SynapseError(400, "Invalid username")
-
- username = body["username"].encode("utf-8")
- if b"\x00" in username:
- raise SynapseError(400, "Invalid username")
-
- if "password" not in body:
- raise SynapseError(
- 400, "password must be specified", errcode=Codes.BAD_JSON
- )
- else:
- if (
- not isinstance(body["password"], text_type)
- or len(body["password"]) > 512
- ):
- raise SynapseError(400, "Invalid password")
-
- password = body["password"].encode("utf-8")
- if b"\x00" in password:
- raise SynapseError(400, "Invalid password")
-
- admin = body.get("admin", None)
- user_type = body.get("user_type", None)
-
- if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
- raise SynapseError(400, "Invalid user type")
-
- got_mac = body["mac"]
-
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- digestmod=hashlib.sha1,
- )
- want_mac.update(nonce.encode("utf8"))
- want_mac.update(b"\x00")
- want_mac.update(username)
- want_mac.update(b"\x00")
- want_mac.update(password)
- want_mac.update(b"\x00")
- want_mac.update(b"admin" if admin else b"notadmin")
- if user_type:
- want_mac.update(b"\x00")
- want_mac.update(user_type.encode("utf8"))
- want_mac = want_mac.hexdigest()
-
- if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
- raise SynapseError(403, "HMAC incorrect")
-
- # Reuse the parts of RegisterRestServlet to reduce code duplication
- from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-
- register = RegisterRestServlet(self.hs)
-
- user_id = yield register.registration_handler.register_user(
- localpart=body["username"].lower(),
- password=body["password"],
- admin=bool(admin),
- user_type=user_type,
- )
-
- result = yield register._create_registration_details(user_id, body)
- return 200, result
-
-
-class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
- auth_user = requester.user
-
- if target_user != auth_user:
- yield assert_user_is_admin(self.auth, auth_user)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only whois a local user")
-
- ret = yield self.handlers.admin_handler.get_whois(target_user)
-
- return 200, ret
-
-
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
@@ -255,9 +86,8 @@ class PurgeHistoryRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request, room_id, event_id):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
@@ -270,12 +100,12 @@ class PurgeHistoryRestServlet(RestServlet):
event_id = body.get("purge_up_to_event_id")
if event_id is not None:
- event = yield self.store.get_event(event_id)
+ event = await self.store.get_event(event_id)
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- token = yield self.store.get_topological_token_for_event(event_id)
+ token = await self.store.get_topological_token_for_event(event_id)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body:
@@ -285,15 +115,13 @@ class PurgeHistoryRestServlet(RestServlet):
400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
)
- stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts))
+ stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
- r = (
- yield self.store.get_room_event_after_stream_ordering(
- room_id, stream_ordering
- )
+ r = await self.store.get_room_event_before_stream_ordering(
+ room_id, stream_ordering
)
if not r:
- logger.warn(
+ logger.warning(
"[purge] purging events not possible: No event found "
"(received_ts %i => stream_ordering %i)",
ts,
@@ -318,7 +146,7 @@ class PurgeHistoryRestServlet(RestServlet):
errcode=Codes.BAD_JSON,
)
- purge_id = yield self.pagination_handler.start_purge_history(
+ purge_id = self.pagination_handler.start_purge_history(
room_id, token, delete_local_events=delete_local_events
)
@@ -339,9 +167,8 @@ class PurgeHistoryStatusRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, purge_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_GET(self, request, purge_id):
+ await assert_requester_is_admin(self.auth, request)
purge_status = self.pagination_handler.get_purge_status(purge_id)
if purge_status is None:
@@ -350,375 +177,6 @@ class PurgeHistoryStatusRestServlet(RestServlet):
return 200, purge_status.asdict()
-class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self._deactivate_account_handler = hs.get_deactivate_account_handler()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
- yield assert_requester_is_admin(self.auth, request)
- body = parse_json_object_from_request(request, allow_empty_body=True)
- erase = body.get("erase", False)
- if not isinstance(erase, bool):
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Param 'erase' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- UserID.from_string(target_user_id)
-
- result = yield self._deactivate_account_handler.deactivate_account(
- target_user_id, erase
- )
- if result:
- id_server_unbind_result = "success"
- else:
- id_server_unbind_result = "no-support"
-
- return 200, {"id_server_unbind_result": id_server_unbind_result}
-
-
-class ShutdownRoomRestServlet(RestServlet):
- """Shuts down a room by removing all local users from the room and blocking
- all future invites and joins to the room. Any local aliases will be repointed
- to a new room created by `new_room_user_id` and kicked users will be auto
- joined to the new room.
- """
-
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
-
- DEFAULT_MESSAGE = (
- "Sharing illegal content on this server is not permitted and rooms in"
- " violation will be blocked."
- )
-
- def __init__(self, hs):
- self.hs = hs
- self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
- self._room_creation_handler = hs.get_room_creation_handler()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
-
- content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ["new_room_user_id"])
- new_room_user_id = content["new_room_user_id"]
-
- room_creator_requester = create_requester(new_room_user_id)
-
- message = content.get("message", self.DEFAULT_MESSAGE)
- room_name = content.get("room_name", "Content Violation Notification")
-
- info = yield self._room_creation_handler.create_room(
- room_creator_requester,
- config={
- "preset": "public_chat",
- "name": room_name,
- "power_level_content_override": {"users_default": -10},
- },
- ratelimit=False,
- )
- new_room_id = info["room_id"]
-
- requester_user_id = requester.user.to_string()
-
- logger.info(
- "Shutting down room %r, joining to new room: %r", room_id, new_room_id
- )
-
- # This will work even if the room is already blocked, but that is
- # desirable in case the first attempt at blocking the room failed below.
- yield self.store.block_room(room_id, requester_user_id)
-
- users = yield self.state.get_current_users_in_room(room_id)
- kicked_users = []
- failed_to_kick_users = []
- for user_id in users:
- if not self.hs.is_mine_id(user_id):
- continue
-
- logger.info("Kicking %r from %r...", user_id, room_id)
-
- try:
- target_requester = create_requester(user_id)
- yield self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=room_id,
- action=Membership.LEAVE,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- yield self.room_member_handler.forget(target_requester.user, room_id)
-
- yield self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=new_room_id,
- action=Membership.JOIN,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- kicked_users.append(user_id)
- except Exception:
- logger.exception(
- "Failed to leave old room and join new room for %r", user_id
- )
- failed_to_kick_users.append(user_id)
-
- yield self.event_creation_handler.create_and_send_nonmember_event(
- room_creator_requester,
- {
- "type": "m.room.message",
- "content": {"body": message, "msgtype": "m.text"},
- "room_id": new_room_id,
- "sender": new_room_user_id,
- },
- ratelimit=False,
- )
-
- aliases_for_room = yield self.store.get_aliases_for_room(room_id)
-
- yield self.store.update_aliases_for_room(
- room_id, new_room_id, requester_user_id
- )
-
- return (
- 200,
- {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- },
- )
-
-
-class ResetPasswordRestServlet(RestServlet):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/reset_password/
- @user:to_reset_password?access_token=admin_access_token
- JsonBodyToSend:
- {
- "new_password": "secret"
- }
- Returns:
- 200 OK with empty object if success otherwise an error.
- """
-
- PATTERNS = historical_admin_path_patterns(
- "/reset_password/(?P<target_user_id>[^/]*)"
- )
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self._set_password_handler = hs.get_set_password_handler()
-
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- """
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
-
- UserID.from_string(target_user_id)
-
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["new_password"])
- new_password = params["new_password"]
-
- yield self._set_password_handler.set_password(
- target_user_id, new_password, requester
- )
- return 200, {}
-
-
-class GetUsersPaginatedRestServlet(RestServlet):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token&start=0&limit=10
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
-
- PATTERNS = historical_admin_path_patterns(
- "/users_paginate/(?P<target_user_id>[^/]*)"
- )
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, target_user_id):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- """
- yield assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- order = "name" # order by name in user table
- start = parse_integer(request, "start", required=True)
- limit = parse_integer(request, "limit", required=True)
-
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
- return 200, ret
-
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
- """Post request to get specific number of users from Synapse..
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token
- JsonBodyToSend:
- {
- "start": "0",
- "limit": "10
- }
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
- yield assert_requester_is_admin(self.auth, request)
- UserID.from_string(target_user_id)
-
- order = "name" # order by name in user table
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["limit", "start"])
- limit = params["limit"]
- start = params["start"]
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
- return 200, ret
-
-
-class SearchUsersRestServlet(RestServlet):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/search_users/
- @admin:user?access_token=admin_access_token&term=alice
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
-
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, target_user_id):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have a administrator access in Synapse.
- """
- yield assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- # To allow all users to get the users list
- # if not is_admin and target_user != auth_user:
- # raise AuthError(403, "You are not a server admin")
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- term = parse_string(request, "term", required=True)
- logger.info("term: %s ", term)
-
- ret = yield self.handlers.admin_handler.search_users(term)
- return 200, ret
-
-
-class DeleteGroupAdminRestServlet(RestServlet):
- """Allows deleting of local groups
- """
-
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
-
- def __init__(self, hs):
- self.group_server = hs.get_groups_server_handler()
- self.is_mine_id = hs.is_mine_id
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
-
- if not self.is_mine_id(group_id):
- raise SynapseError(400, "Can only delete local groups")
-
- yield self.group_server.delete_group(group_id, requester.user.to_string())
- return 200, {}
-
-
-class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- self.hs = hs
- self.account_activity_handler = hs.get_account_validity_handler()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield assert_requester_is_admin(self.auth, request)
-
- body = parse_json_object_from_request(request)
-
- if "user_id" not in body:
- raise SynapseError(400, "Missing property 'user_id' in the request body")
-
- expiration_ts = yield self.account_activity_handler.renew_account_for_user(
- body["user_id"],
- body.get("expiration_ts"),
- not body.get("enable_renewal_emails", True),
- )
-
- res = {"expiration_ts": expiration_ts}
- return 200, res
-
-
########################################################################################
#
# please don't add more servlets here: this file is already long and unwieldy. Put
@@ -740,10 +198,18 @@ def register_servlets(hs, http_server):
Register all the admin servlets.
"""
register_servlets_for_client_rest_resource(hs, http_server)
+ ListRoomRestServlet(hs).register(http_server)
+ RoomRestServlet(hs).register(http_server)
+ JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
+ UserRestServletV2(hs).register(http_server)
+ UsersRestServletV2(hs).register(http_server)
+ DeviceRestServlet(hs).register(http_server)
+ DevicesRestServlet(hs).register(http_server)
+ DeleteDevicesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
@@ -754,7 +220,6 @@ def register_servlets_for_client_rest_resource(hs, http_server):
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
- GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 5a9b08d3ef..d82eaf5e38 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -15,9 +15,11 @@
import re
-from twisted.internet import defer
+import twisted.web.server
+import synapse.api.auth
from synapse.api.errors import AuthError
+from synapse.types import UserID
def historical_admin_path_patterns(path_regex):
@@ -31,7 +33,7 @@ def historical_admin_path_patterns(path_regex):
Note that this should only be used for existing endpoints: new ones should just
register for the /_synapse/admin path.
"""
- return list(
+ return [
re.compile(prefix + path_regex)
for prefix in (
"^/_synapse/admin/v1",
@@ -39,46 +41,50 @@ def historical_admin_path_patterns(path_regex):
"^/_matrix/client/unstable/admin",
"^/_matrix/client/r0/admin",
)
- )
+ ]
-@defer.inlineCallbacks
-def assert_requester_is_admin(auth, request):
- """Verify that the requester is an admin user
-
- WARNING: MAKE SURE YOU YIELD ON THE RESULT!
+def admin_patterns(path_regex: str):
+ """Returns the list of patterns for an admin endpoint
Args:
- auth (synapse.api.auth.Auth):
- request (twisted.web.server.Request): incoming request
+ path_regex: The regex string to match. This should NOT have a ^
+ as this will be prefixed.
Returns:
- Deferred
+ A list of regex patterns.
+ """
+ admin_prefix = "^/_synapse/admin/v1"
+ patterns = [re.compile(admin_prefix + path_regex)]
+ return patterns
+
+
+async def assert_requester_is_admin(
+ auth: synapse.api.auth.Auth, request: twisted.web.server.Request
+) -> None:
+ """Verify that the requester is an admin user
+
+ Args:
+ auth: api.auth.Auth singleton
+ request: incoming request
Raises:
- AuthError if the requester is not an admin
+ AuthError if the requester is not a server admin
"""
- requester = yield auth.get_user_by_req(request)
- yield assert_user_is_admin(auth, requester.user)
+ requester = await auth.get_user_by_req(request)
+ await assert_user_is_admin(auth, requester.user)
-@defer.inlineCallbacks
-def assert_user_is_admin(auth, user_id):
+async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
- WARNING: MAKE SURE YOU YIELD ON THE RESULT!
-
Args:
- auth (synapse.api.auth.Auth):
- user_id (UserID):
-
- Returns:
- Deferred
+ auth: api.auth.Auth singleton
+ user_id: user to check
Raises:
- AuthError if the user is not an admin
+ AuthError if the user is not a server admin
"""
-
- is_admin = yield auth.is_server_admin(user_id)
+ is_admin = await auth.is_server_admin(user_id)
if not is_admin:
raise AuthError(403, "You are not a server admin")
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
new file mode 100644
index 0000000000..8d32677339
--- /dev/null
+++ b/synapse/rest/admin/devices.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from synapse.api.errors import NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.rest.admin._base import assert_requester_is_admin
+from synapse.types import UserID
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceRestServlet(RestServlet):
+ """
+ Get, update or delete the given user's device
+ """
+
+ PATTERNS = (
+ re.compile(
+ "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$"
+ ),
+ )
+
+ def __init__(self, hs):
+ super(DeviceRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, user_id, device_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ device = await self.device_handler.get_device(
+ target_user.to_string(), device_id
+ )
+ return 200, device
+
+ async def on_DELETE(self, request, user_id, device_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ await self.device_handler.delete_device(target_user.to_string(), device_id)
+ return 200, {}
+
+ async def on_PUT(self, request, user_id, device_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ await self.device_handler.update_device(
+ target_user.to_string(), device_id, body
+ )
+ return 200, {}
+
+
+class DevicesRestServlet(RestServlet):
+ """
+ Retrieve the given user's devices
+ """
+
+ PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),)
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ devices = await self.device_handler.get_devices_by_user(target_user.to_string())
+ return 200, {"devices": devices}
+
+
+class DeleteDevicesRestServlet(RestServlet):
+ """
+ API for bulk deletion of devices. Accepts a JSON object with a devices
+ key which lists the device_ids to delete.
+ """
+
+ PATTERNS = (
+ re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"),
+ )
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
+
+ async def on_POST(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ body = parse_json_object_from_request(request, allow_empty_body=False)
+ assert_params_in_dict(body, ["devices"])
+
+ await self.device_handler.delete_devices(
+ target_user.to_string(), body["devices"]
+ )
+ return 200, {}
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
new file mode 100644
index 0000000000..0b54ca09f4
--- /dev/null
+++ b/synapse/rest/admin/groups.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.rest.admin._base import (
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DeleteGroupAdminRestServlet(RestServlet):
+ """Allows deleting of local groups
+ """
+
+ PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.group_server = hs.get_groups_server_handler()
+ self.is_mine_id = hs.is_mine_id
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Can only delete local groups")
+
+ await self.group_server.delete_group(group_id, requester.user.to_string())
+ return 200, {}
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index ed7086d09c..ee75095c0e 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_integer
from synapse.rest.admin._base import (
@@ -34,24 +32,85 @@ class QuarantineMediaInRoom(RestServlet):
this server.
"""
- PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ PATTERNS = (
+ historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ +
+ # This path kept around for legacy reasons
+ historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ )
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
+ async def on_POST(self, request, room_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining room: %s", room_id)
- num_quarantined = yield self.store.quarantine_media_ids_in_room(
+ # Quarantine all media in this room
+ num_quarantined = await self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string()
)
return 200, {"num_quarantined": num_quarantined}
+class QuarantineMediaByUser(RestServlet):
+ """Quarantines all local media by a given user so that no one can download it via
+ this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/user/(?P<user_id>[^/]+)/media/quarantine"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by user: %s", user_id)
+
+ # Quarantine all media this user has uploaded
+ num_quarantined = await self.store.quarantine_media_ids_by_user(
+ user_id, requester.user.to_string()
+ )
+
+ return 200, {"num_quarantined": num_quarantined}
+
+
+class QuarantineMediaByID(RestServlet):
+ """Quarantines local or remote media by a given ID so that no one can download
+ it via this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, server_name: str, media_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
+
+ # Quarantine this media id
+ await self.store.quarantine_media_by_id(
+ server_name, media_id, requester.user.to_string()
+ )
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
@@ -62,14 +121,13 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
- local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+ local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
return 200, {"local": local_mxcs, "remote": remote_mxcs}
@@ -81,14 +139,13 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
logger.info("before_ts: %r", before_ts)
- ret = yield self.media_repository.delete_old_remote_media(before_ts)
+ ret = await self.media_repository.delete_old_remote_media(before_ts)
return 200, ret
@@ -99,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server):
"""
PurgeMediaCacheRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ QuarantineMediaByID(hs).register(http_server)
+ QuarantineMediaByUser(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
new file mode 100644
index 0000000000..8173baef8f
--- /dev/null
+++ b/synapse/rest/admin/rooms.py
@@ -0,0 +1,364 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import List, Optional
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.admin._base import (
+ admin_patterns,
+ assert_requester_is_admin,
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+from synapse.storage.data_stores.main.room import RoomSortOrder
+from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.util.async_helpers import maybe_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+class ShutdownRoomRestServlet(RestServlet):
+ """Shuts down a room by removing all local users from the room and blocking
+ all future invites and joins to the room. Any local aliases will be repointed
+ to a new room created by `new_room_user_id` and kicked users will be auto
+ joined to the new room.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+
+ DEFAULT_MESSAGE = (
+ "Sharing illegal content on this server is not permitted and rooms in"
+ " violation will be blocked."
+ )
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+ self._replication = hs.get_replication_data_handler()
+
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
+
+ room_creator_requester = create_requester(new_room_user_id)
+
+ message = content.get("message", self.DEFAULT_MESSAGE)
+ room_name = content.get("room_name", "Content Violation Notification")
+
+ info, stream_id = await self._room_creation_handler.create_room(
+ room_creator_requester,
+ config={
+ "preset": "public_chat",
+ "name": room_name,
+ "power_level_content_override": {"users_default": -10},
+ },
+ ratelimit=False,
+ )
+ new_room_id = info["room_id"]
+
+ requester_user_id = requester.user.to_string()
+
+ logger.info(
+ "Shutting down room %r, joining to new room: %r", room_id, new_room_id
+ )
+
+ # This will work even if the room is already blocked, but that is
+ # desirable in case the first attempt at blocking the room failed below.
+ await self.store.block_room(room_id, requester_user_id)
+
+ # We now wait for the create room to come back in via replication so
+ # that we can assume that all the joins/invites have propogated before
+ # we try and auto join below.
+ #
+ # TODO: Currently the events stream is written to from master
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+
+ users = await self.state.get_current_users_in_room(room_id)
+ kicked_users = []
+ failed_to_kick_users = []
+ for user_id in users:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ logger.info("Kicking %r from %r...", user_id, room_id)
+
+ try:
+ target_requester = create_requester(user_id)
+ _, stream_id = await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=room_id,
+ action=Membership.LEAVE,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ # Wait for leave to come in over replication before trying to forget.
+ await self._replication.wait_for_stream_position(
+ self.hs.config.worker.writers.events, "events", stream_id
+ )
+
+ await self.room_member_handler.forget(target_requester.user, room_id)
+
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ kicked_users.append(user_id)
+ except Exception:
+ logger.exception(
+ "Failed to leave old room and join new room for %r", user_id
+ )
+ failed_to_kick_users.append(user_id)
+
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ room_creator_requester,
+ {
+ "type": "m.room.message",
+ "content": {"body": message, "msgtype": "m.text"},
+ "room_id": new_room_id,
+ "sender": new_room_user_id,
+ },
+ ratelimit=False,
+ )
+
+ aliases_for_room = await maybe_awaitable(
+ self.store.get_aliases_for_room(room_id)
+ )
+
+ await self.store.update_aliases_for_room(
+ room_id, new_room_id, requester_user_id
+ )
+
+ return (
+ 200,
+ {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ },
+ )
+
+
+class ListRoomRestServlet(RestServlet):
+ """
+ List all rooms that are known to the homeserver. Results are returned
+ in a dictionary containing room information. Supports pagination.
+ """
+
+ PATTERNS = admin_patterns("/rooms$")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ # Extract query parameters
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
+ if order_by not in (
+ RoomSortOrder.ALPHABETICAL.value,
+ RoomSortOrder.SIZE.value,
+ RoomSortOrder.NAME.value,
+ RoomSortOrder.CANONICAL_ALIAS.value,
+ RoomSortOrder.JOINED_MEMBERS.value,
+ RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
+ RoomSortOrder.VERSION.value,
+ RoomSortOrder.CREATOR.value,
+ RoomSortOrder.ENCRYPTION.value,
+ RoomSortOrder.FEDERATABLE.value,
+ RoomSortOrder.PUBLIC.value,
+ RoomSortOrder.JOIN_RULES.value,
+ RoomSortOrder.GUEST_ACCESS.value,
+ RoomSortOrder.HISTORY_VISIBILITY.value,
+ RoomSortOrder.STATE_EVENTS.value,
+ ):
+ raise SynapseError(
+ 400,
+ "Unknown value for order_by: %s" % (order_by,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ search_term = parse_string(request, "search_term")
+ if search_term == "":
+ raise SynapseError(
+ 400,
+ "search_term cannot be an empty string",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ direction = parse_string(request, "dir", default="f")
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ reverse_order = True if direction == "b" else False
+
+ # Return list of rooms according to parameters
+ rooms, total_rooms = await self.store.get_rooms_paginate(
+ start, limit, order_by, reverse_order, search_term
+ )
+ response = {
+ # next_token should be opaque, so return a value the client can parse
+ "offset": start,
+ "rooms": rooms,
+ "total_rooms": total_rooms,
+ }
+
+ # Are there more rooms to paginate through after this?
+ if (start + limit) < total_rooms:
+ # There are. Calculate where the query should start from next time
+ # to get the next part of the list
+ response["next_batch"] = start + limit
+
+ # Is it possible to paginate backwards? Check if we currently have an
+ # offset
+ if start > 0:
+ if start > limit:
+ # Going back one iteration won't take us to the start.
+ # Calculate new offset
+ response["prev_batch"] = start - limit
+ else:
+ response["prev_batch"] = 0
+
+ return 200, response
+
+
+class RoomRestServlet(RestServlet):
+ """Get room details.
+
+ TODO: Add on_POST to allow room creation without joining the room
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, room_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ ret = await self.store.get_room_with_stats(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ return 200, ret
+
+
+class JoinRoomAliasServlet(RestServlet):
+
+ PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.admin_handler = hs.get_handlers().admin_handler
+ self.state_handler = hs.get_state_handler()
+
+ async def on_POST(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+
+ assert_params_in_dict(content, ["user_id"])
+ target_user = UserID.from_string(content["user_id"])
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "This endpoint can only be used with local users")
+
+ if not await self.admin_handler.get_user(target_user):
+ raise NotFoundError("User not found")
+
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ try:
+ remote_room_hosts = [
+ x.decode("ascii") for x in request.args[b"server_name"]
+ ] # type: Optional[List[str]]
+ except Exception:
+ remote_room_hosts = None
+ elif RoomAlias.is_valid(room_identifier):
+ handler = self.room_member_handler
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ fake_requester = create_requester(target_user)
+
+ # send invite if room has "JoinRules.INVITE"
+ room_state = await self.state_handler.get_current_state(room_id)
+ join_rules_event = room_state.get((EventTypes.JoinRules, ""))
+ if join_rules_event:
+ if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=fake_requester.user,
+ room_id=room_id,
+ action="invite",
+ remote_room_hosts=remote_room_hosts,
+ ratelimit=False,
+ )
+
+ await self.room_member_handler.update_membership(
+ requester=fake_requester,
+ target=fake_requester.user,
+ room_id=room_id,
+ action="join",
+ remote_room_hosts=remote_room_hosts,
+ ratelimit=False,
+ )
+
+ return 200, {"room_id": room_id}
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index ae2cbe2e0a..6e9a874121 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -14,8 +14,6 @@
# limitations under the License.
import re
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -69,9 +67,8 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
- @defer.inlineCallbacks
- def on_POST(self, request, txn_id=None):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request, txn_id=None):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
event_type = body.get("type", EventTypes.Message)
@@ -85,7 +82,7 @@ class SendServerNoticeServlet(RestServlet):
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Server notices can only be sent to local users")
- event = yield self.snm.send_notice(
+ event = await self.snm.send_notice(
user_id=body["user_id"],
type=event_type,
state_key=state_key,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 9720a3bab0..fefc8f71fa 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -12,19 +12,607 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import hmac
+import logging
import re
-from twisted.internet import defer
+from six import text_type
+from six.moves import http_client
-from synapse.api.errors import SynapseError
+from synapse.api.constants import UserTypes
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_boolean,
+ parse_integer,
parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.admin._base import (
+ assert_requester_is_admin,
+ assert_user_is_admin,
+ historical_admin_path_patterns,
)
-from synapse.rest.admin import assert_requester_is_admin, assert_user_is_admin
from synapse.types import UserID
+logger = logging.getLogger(__name__)
+
+
+class UsersRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ ret = await self.store.get_users()
+
+ return 200, ret
+
+
+class UsersRestServletV2(RestServlet):
+ PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
+
+ """Get request to list all local users.
+ This needs user to have administrator access in Synapse.
+
+ GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
+
+ returns:
+ 200 OK with list of users if success otherwise an error.
+
+ The parameters `from` and `limit` are required only for pagination.
+ By default, a `limit` of 100 is used.
+ The parameter `user_id` can be used to filter by user id.
+ The parameter `guests` can be used to exclude guest users.
+ The parameter `deactivated` can be used to include deactivated users.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ user_id = parse_string(request, "user_id", default=None)
+ guests = parse_boolean(request, "guests", default=True)
+ deactivated = parse_boolean(request, "deactivated", default=False)
+
+ users, total = await self.store.get_users_paginate(
+ start, limit, user_id, guests, deactivated
+ )
+ ret = {"users": users, "total": total}
+ if len(users) >= limit:
+ ret["next_token"] = str(start + len(users))
+
+ return 200, ret
+
+
+class UserRestServletV2(RestServlet):
+ PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),)
+
+ """Get request to list user details.
+ This needs user to have administrator access in Synapse.
+
+ GET /_synapse/admin/v2/users/<user_id>
+
+ returns:
+ 200 OK with user details if success otherwise an error.
+
+ Put request to allow an administrator to add or modify a user.
+ This needs user to have administrator access in Synapse.
+ We use PUT instead of POST since we already know the id of the user
+ object to create. POST could be used to create guests.
+
+ PUT /_synapse/admin/v2/users/<user_id>
+ {
+ "password": "secret",
+ "displayname": "User"
+ }
+
+ returns:
+ 201 OK with new user object if user was created or
+ 200 OK with modified user object if user was modified
+ otherwise an error.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+ self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
+ self.profile_handler = hs.get_profile_handler()
+ self.set_password_handler = hs.get_set_password_handler()
+ self.deactivate_account_handler = hs.get_deactivate_account_handler()
+ self.registration_handler = hs.get_registration_handler()
+ self.pusher_pool = hs.get_pusherpool()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ ret = await self.admin_handler.get_user(target_user)
+
+ if not ret:
+ raise NotFoundError("User not found")
+
+ return 200, ret
+
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ target_user = UserID.from_string(user_id)
+ body = parse_json_object_from_request(request)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "This endpoint can only be used with local users")
+
+ user = await self.admin_handler.get_user(target_user)
+ user_id = target_user.to_string()
+
+ if user: # modify user
+ if "displayname" in body:
+ await self.profile_handler.set_displayname(
+ target_user, requester, body["displayname"], True
+ )
+
+ if "threepids" in body:
+ # check for required parameters for each threepid
+ for threepid in body["threepids"]:
+ assert_params_in_dict(threepid, ["medium", "address"])
+
+ # remove old threepids from user
+ threepids = await self.store.user_get_threepids(user_id)
+ for threepid in threepids:
+ try:
+ await self.auth_handler.delete_threepid(
+ user_id, threepid["medium"], threepid["address"], None
+ )
+ except Exception:
+ logger.exception("Failed to remove threepids")
+ raise SynapseError(500, "Failed to remove threepids")
+
+ # add new threepids to user
+ current_time = self.hs.get_clock().time_msec()
+ for threepid in body["threepids"]:
+ await self.auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], current_time
+ )
+
+ if "avatar_url" in body and type(body["avatar_url"]) == str:
+ await self.profile_handler.set_avatar_url(
+ target_user, requester, body["avatar_url"], True
+ )
+
+ if "admin" in body:
+ set_admin_to = bool(body["admin"])
+ if set_admin_to != user["admin"]:
+ auth_user = requester.user
+ if target_user == auth_user and not set_admin_to:
+ raise SynapseError(400, "You may not demote yourself.")
+
+ await self.store.set_server_admin(target_user, set_admin_to)
+
+ if "password" in body:
+ if (
+ not isinstance(body["password"], text_type)
+ or len(body["password"]) > 512
+ ):
+ raise SynapseError(400, "Invalid password")
+ else:
+ new_password = body["password"]
+ logout_devices = True
+
+ new_password_hash = await self.auth_handler.hash(new_password)
+
+ await self.set_password_handler.set_password(
+ target_user.to_string(),
+ new_password_hash,
+ logout_devices,
+ requester,
+ )
+
+ if "deactivated" in body:
+ deactivate = body["deactivated"]
+ if not isinstance(deactivate, bool):
+ raise SynapseError(
+ 400, "'deactivated' parameter is not of type boolean"
+ )
+
+ if deactivate and not user["deactivated"]:
+ await self.deactivate_account_handler.deactivate_account(
+ target_user.to_string(), False
+ )
+
+ user = await self.admin_handler.get_user(target_user)
+ return 200, user
+
+ else: # create user
+ password = body.get("password")
+ password_hash = None
+ if password is not None:
+ if not isinstance(password, text_type) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+ password_hash = await self.auth_handler.hash(password)
+
+ admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+ displayname = body.get("displayname", None)
+ threepids = body.get("threepids", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
+ user_id = await self.registration_handler.register_user(
+ localpart=target_user.localpart,
+ password_hash=password_hash,
+ admin=bool(admin),
+ default_display_name=displayname,
+ user_type=user_type,
+ by_admin=True,
+ )
+
+ if "threepids" in body:
+ # check for required parameters for each threepid
+ for threepid in body["threepids"]:
+ assert_params_in_dict(threepid, ["medium", "address"])
+
+ current_time = self.hs.get_clock().time_msec()
+ for threepid in body["threepids"]:
+ await self.auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], current_time
+ )
+ if (
+ self.hs.config.email_enable_notifs
+ and self.hs.config.email_notif_for_new_users
+ ):
+ await self.pusher_pool.add_pusher(
+ user_id=user_id,
+ access_token=None,
+ kind="email",
+ app_id="m.email",
+ app_display_name="Email Notifications",
+ device_display_name=threepid["address"],
+ pushkey=threepid["address"],
+ lang=None, # We don't know a user's language here
+ data={},
+ )
+
+ if "avatar_url" in body and type(body["avatar_url"]) == str:
+ await self.profile_handler.set_avatar_url(
+ user_id, requester, body["avatar_url"], True
+ )
+
+ ret = await self.admin_handler.get_user(target_user)
+
+ return 201, ret
+
+
+class UserRegisterServlet(RestServlet):
+ """
+ Attributes:
+ NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
+ nonces (dict[str, int]): The nonces that we will accept. A dict of
+ nonce to the time it was generated, in int seconds.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/register")
+ NONCE_TIMEOUT = 60
+
+ def __init__(self, hs):
+ self.auth_handler = hs.get_auth_handler()
+ self.reactor = hs.get_reactor()
+ self.nonces = {}
+ self.hs = hs
+
+ def _clear_old_nonces(self):
+ """
+ Clear out old nonces that are older than NONCE_TIMEOUT.
+ """
+ now = int(self.reactor.seconds())
+
+ for k, v in list(self.nonces.items()):
+ if now - v > self.NONCE_TIMEOUT:
+ del self.nonces[k]
+
+ def on_GET(self, request):
+ """
+ Generate a new nonce.
+ """
+ self._clear_old_nonces()
+
+ nonce = self.hs.get_secrets().token_hex(64)
+ self.nonces[nonce] = int(self.reactor.seconds())
+ return 200, {"nonce": nonce}
+
+ async def on_POST(self, request):
+ self._clear_old_nonces()
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ body = parse_json_object_from_request(request)
+
+ if "nonce" not in body:
+ raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
+
+ nonce = body["nonce"]
+
+ if nonce not in self.nonces:
+ raise SynapseError(400, "unrecognised nonce")
+
+ # Delete the nonce, so it can't be reused, even if it's invalid
+ del self.nonces[nonce]
+
+ if "username" not in body:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ if (
+ not isinstance(body["username"], text_type)
+ or len(body["username"]) > 512
+ ):
+ raise SynapseError(400, "Invalid username")
+
+ username = body["username"].encode("utf-8")
+ if b"\x00" in username:
+ raise SynapseError(400, "Invalid username")
+
+ if "password" not in body:
+ raise SynapseError(
+ 400, "password must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ password = body["password"]
+ if not isinstance(password, text_type) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+
+ password_bytes = password.encode("utf-8")
+ if b"\x00" in password_bytes:
+ raise SynapseError(400, "Invalid password")
+
+ password_hash = await self.auth_handler.hash(password)
+
+ admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
+ got_mac = body["mac"]
+
+ want_mac_builder = hmac.new(
+ key=self.hs.config.registration_shared_secret.encode(),
+ digestmod=hashlib.sha1,
+ )
+ want_mac_builder.update(nonce.encode("utf8"))
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(username)
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(password_bytes)
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(b"admin" if admin else b"notadmin")
+ if user_type:
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(user_type.encode("utf8"))
+
+ want_mac = want_mac_builder.hexdigest()
+
+ if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
+ raise SynapseError(403, "HMAC incorrect")
+
+ # Reuse the parts of RegisterRestServlet to reduce code duplication
+ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+
+ register = RegisterRestServlet(self.hs)
+
+ user_id = await register.registration_handler.register_user(
+ localpart=body["username"].lower(),
+ password_hash=password_hash,
+ admin=bool(admin),
+ user_type=user_type,
+ by_admin=True,
+ )
+
+ result = await register._create_registration_details(user_id, body)
+ return 200, result
+
+
+class WhoisRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ requester = await self.auth.get_user_by_req(request)
+ auth_user = requester.user
+
+ if target_user != auth_user:
+ await assert_user_is_admin(self.auth, auth_user)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only whois a local user")
+
+ ret = await self.handlers.admin_handler.get_whois(target_user)
+
+ return 200, ret
+
+
+class DeactivateAccountRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, target_user_id):
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ UserID.from_string(target_user_id)
+
+ result = await self._deactivate_account_handler.deactivate_account(
+ target_user_id, erase
+ )
+ if result:
+ id_server_unbind_result = "success"
+ else:
+ id_server_unbind_result = "no-support"
+
+ return 200, {"id_server_unbind_result": id_server_unbind_result}
+
+
+class AccountValidityRenewServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.account_activity_handler = hs.get_account_validity_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ body = parse_json_object_from_request(request)
+
+ if "user_id" not in body:
+ raise SynapseError(400, "Missing property 'user_id' in the request body")
+
+ expiration_ts = await self.account_activity_handler.renew_account_for_user(
+ body["user_id"],
+ body.get("expiration_ts"),
+ not body.get("enable_renewal_emails", True),
+ )
+
+ res = {"expiration_ts": expiration_ts}
+ return 200, res
+
+
+class ResetPasswordRestServlet(RestServlet):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/reset_password/
+ @user:to_reset_password?access_token=admin_access_token
+ JsonBodyToSend:
+ {
+ "new_password": "secret"
+ }
+ Returns:
+ 200 OK with empty object if success otherwise an error.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/reset_password/(?P<target_user_id>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+ self._set_password_handler = hs.get_set_password_handler()
+
+ async def on_POST(self, request, target_user_id):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ """
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ UserID.from_string(target_user_id)
+
+ params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
+ new_password = params["new_password"]
+ logout_devices = params.get("logout_devices", True)
+
+ new_password_hash = await self.auth_handler.hash(new_password)
+
+ await self._set_password_handler.set_password(
+ target_user_id, new_password_hash, logout_devices, requester
+ )
+ return 200, {}
+
+
+class SearchUsersRestServlet(RestServlet):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/search_users/
+ @admin:user?access_token=admin_access_token&term=alice
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, target_user_id):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have a administrator access in Synapse.
+ """
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(target_user_id)
+
+ # To allow all users to get the users list
+ # if not is_admin and target_user != auth_user:
+ # raise AuthError(403, "You are not a server admin")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ term = parse_string(request, "term", required=True)
+ logger.info("term: %s ", term)
+
+ ret = await self.handlers.store.search_users(term)
+ return 200, ret
+
class UserAdminServlet(RestServlet):
"""
@@ -52,31 +640,28 @@ class UserAdminServlet(RestServlet):
{}
"""
- PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>@[^/]*)/admin$"),)
+ PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),)
def __init__(self, hs):
self.hs = hs
+ self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Only local users can be admins of this homeserver")
- is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user)
- is_admin = bool(is_admin)
+ is_admin = await self.store.is_server_admin(target_user)
return 200, {"admin": is_admin}
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
target_user = UserID.from_string(user_id)
@@ -93,8 +678,6 @@ class UserAdminServlet(RestServlet):
if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.")
- yield self.handlers.admin_handler.set_user_server_admin(
- target_user, set_admin_to
- )
+ await self.store.set_server_admin(target_user, set_admin_to)
return 200, {}
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 4ea3666874..5934b1fe8b 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import (
AuthError,
Codes,
@@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_alias):
+ async def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
dir_handler = self.handlers.directory_handler
- res = yield dir_handler.get_association(room_alias)
+ res = await dir_handler.get_association(room_alias)
return 200, res
- @defer.inlineCallbacks
- def on_PUT(self, request, room_alias):
+ async def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
@@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet):
# TODO(erikj): Check types.
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Room does not exist")
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
- yield self.handlers.directory_handler.create_association(
+ await self.handlers.directory_handler.create_association(
requester, room_alias, room_id, servers
)
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_alias):
+ async def on_DELETE(self, request, room_alias):
dir_handler = self.handlers.directory_handler
try:
- service = yield self.auth.get_appservice_by_req(request)
+ service = await self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_appservice_association(service, room_alias)
+ await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
"Application service at %s deleted alias %s",
service.url,
@@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet):
# fallback to default user behaviour if they aren't an AS
pass
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_association(requester, room_alias)
+ await dir_handler.delete_association(requester, room_alias)
logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
@@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- room = yield self.store.get_room(room_id)
+ async def on_GET(self, request, room_id):
+ room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- yield self.handlers.directory_handler.edit_published_room_list(
+ await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, visibility
)
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- yield self.handlers.directory_handler.edit_published_room_list(
+ await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, "private"
)
@@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet):
def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private")
- @defer.inlineCallbacks
- def _edit(self, request, network_id, room_id, visibility):
- requester = yield self.auth.get_user_by_req(request)
+ async def _edit(self, request, network_id, room_id, visibility):
+ requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
403, "Only appservices can edit the appservice published room list"
)
- yield self.handlers.directory_handler.edit_published_appservice_room_list(
+ await self.handlers.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility
)
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 6651b4cf07..25effd0261 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -16,8 +16,6 @@
"""This module contains REST servlets to do with event streaming, /events."""
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet):
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
room_id = None
if is_guest:
@@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet):
as_client_event = b"raw" not in request.args
- chunk = yield self.event_stream_handler.get_stream(
+ chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@@ -73,7 +70,6 @@ class EventStreamRestServlet(RestServlet):
return 200, {}
-# TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
@@ -81,16 +77,16 @@ class EventRestServlet(RestServlet):
super(EventRestServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
+ self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request, event_id):
- requester = yield self.auth.get_user_by_req(request)
- event = yield self.event_handler.get_event(requester.user, None, event_id)
+ async def on_GET(self, request, event_id):
+ requester = await self.auth.get_user_by_req(request)
+ event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
- event = yield self._event_serializer.serialize_event(event, time_now)
+ event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
else:
return 404, "Event not found."
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 2da3cd7511..910b3b4eeb 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False)
- content = yield self.initial_sync_handler.snapshot_all_rooms(
+ content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 25a1b67092..dceb2792fa 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,12 +14,6 @@
# limitations under the License.
import logging
-import xml.etree.ElementTree as ET
-
-from six.moves import urllib
-
-from twisted.internet import defer
-from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -29,9 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
logger = logging.getLogger(__name__)
@@ -88,29 +83,42 @@ class LoginRestServlet(RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
+ self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
self._well_known_builder = WellKnownBuilder(hs)
- self._address_ratelimiter = Ratelimiter()
+ self._address_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_address.per_second,
+ burst_count=self.hs.config.rc_login_address.burst_count,
+ )
+ self._account_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_account.per_second,
+ burst_count=self.hs.config.rc_login_account.burst_count,
+ )
+ self._failed_attempts_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
def on_GET(self, request):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- if self.cas_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
+ if self.cas_enabled:
# we advertise CAS for backwards compat, though MSC1721 renamed it
# to SSO.
flows.append({"type": LoginRestServlet.CAS_TYPE})
+ if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
# While its valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
- # CAS login flow.
+ # SSO login flow.
# Generally we don't want to advertise login flows that clients
# don't know how to implement, since they (currently) will always
# fall back to the fallback API if they don't understand one of the
@@ -126,26 +134,19 @@ class LoginRestServlet(RestServlet):
def on_OPTIONS(self, request):
return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
- self._address_ratelimiter.ratelimit(
- request.getClientIP(),
- time_now_s=self.hs.clock.time(),
- rate_hz=self.hs.config.rc_login_address.per_second,
- burst_count=self.hs.config.rc_login_address.burst_count,
- update=True,
- )
+ async def on_POST(self, request):
+ self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
try:
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
):
- result = yield self.do_jwt_login(login_submission)
+ result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = yield self.do_token_login(login_submission)
+ result = await self.do_token_login(login_submission)
else:
- result = yield self._do_other_login(login_submission)
+ result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -154,8 +155,7 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- @defer.inlineCallbacks
- def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins
Args:
@@ -201,28 +201,43 @@ class LoginRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
+
# Check for login providers that support 3pid login types
- canonical_user_id, callback_3pid = (
- yield self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = await self.auth_handler.check_password_provider_3pid(
+ medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
- result = yield self._register_device_with_callback(
+
+ result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid
)
return result
# No password providers were able to handle this 3pid
# Check local store
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
- logger.warn(
+ logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ self._failed_attempts_ratelimiter.can_do_action((medium, address))
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
@@ -234,32 +249,71 @@ class LoginRestServlet(RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- canonical_user_id, callback = yield self.auth_handler.validate_login(
- identifier["user"], login_submission
+ if identifier["user"].startswith("@"):
+ qualified_user_id = identifier["user"]
+ else:
+ qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ self._failed_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(), update=False
)
- result = yield self._register_device_with_callback(
+ try:
+ canonical_user_id, callback = await self.auth_handler.validate_login(
+ identifier["user"], login_submission
+ )
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
+ raise
+
+ result = await self._complete_login(
canonical_user_id, login_submission, callback
)
return result
- @defer.inlineCallbacks
- def _register_device_with_callback(self, user_id, login_submission, callback=None):
- """ Registers a device with a given user_id. Optionally run a callback
- function after registration has completed.
+ async def _complete_login(
+ self, user_id, login_submission, callback=None, create_non_existent_users=False
+ ):
+ """Called when we've successfully authed the user and now need to
+ actually login them in (e.g. create devices). This gets called on
+ all succesful logins.
+
+ Applies the ratelimiting for succesful login attempts against an
+ account.
Args:
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
+ create_non_existent_users (bool): Whether to create the user if
+ they don't exist. Defaults to False.
Returns:
result (Dict[str,str]): Dictionary of account information after
successful registration.
"""
+
+ # Before we actually log them in we check if they've already logged in
+ # too often. This happens here rather than before as we don't
+ # necessarily know the user before now.
+ self._account_ratelimiter.ratelimit(user_id.lower())
+
+ if create_non_existent_users:
+ canonical_uid = await self.auth_handler.check_user_exists(user_id)
+ if not canonical_uid:
+ canonical_uid = await self.registration_handler.register_user(
+ localpart=UserID.from_string(user_id).localpart
+ )
+ user_id = canonical_uid
+
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name
)
@@ -271,23 +325,21 @@ class LoginRestServlet(RestServlet):
}
if callback is not None:
- yield callback(result)
+ await callback(result)
return result
- @defer.inlineCallbacks
- def do_token_login(self, login_submission):
+ async def do_token_login(self, login_submission):
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = (
- yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
)
- result = yield self._register_device_with_callback(user_id, login_submission)
+ result = await self._complete_login(user_id, login_submission)
return result
- @defer.inlineCallbacks
- def do_jwt_login(self, login_submission):
+ async def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
@@ -311,15 +363,8 @@ class LoginRestServlet(RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string()
-
- registered_user_id = yield self.auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id = yield self.registration_handler.register_user(
- localpart=user
- )
-
- result = yield self._register_device_with_callback(
- registered_user_id, login_submission
+ result = await self._complete_login(
+ user_id, login_submission, create_non_existent_users=True
)
return result
@@ -329,24 +374,27 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
- def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
- sso_url = self.get_sso_url(client_redirect_url)
+ sso_url = await self.get_sso_url(request, client_redirect_url)
request.redirect(sso_url)
finish_request(request)
- def get_sso_url(self, client_redirect_url):
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
- client_redirect_url (bytes): the URL that we should redirect the
+ request: The client request to redirect.
+ client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
- bytes: URL to redirect to
+ URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
@@ -354,19 +402,14 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
- super(CasRedirectServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url.encode("ascii")
- self.cas_service_url = hs.config.cas_service_url.encode("ascii")
+ self._cas_handler = hs.get_cas_handler()
- def get_sso_url(self, client_redirect_url):
- client_redirect_url_param = urllib.parse.urlencode(
- {b"redirectUrl": client_redirect_url}
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return self._cas_handler.get_redirect_url(
+ {"redirectUrl": client_redirect_url}
).encode("ascii")
- hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
- service_param = urllib.parse.urlencode(
- {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
- ).encode("ascii")
- return b"%s/login?%s" % (self.cas_server_url, service_param)
class CasTicketServlet(RestServlet):
@@ -374,80 +417,25 @@ class CasTicketServlet(RestServlet):
def __init__(self, hs):
super(CasTicketServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url
- self.cas_service_url = hs.config.cas_service_url
- self.cas_required_attributes = hs.config.cas_required_attributes
- self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_simple_http_client()
-
- @defer.inlineCallbacks
- def on_GET(self, request):
- client_redirect_url = parse_string(request, "redirectUrl", required=True)
- uri = self.cas_server_url + "/proxyValidate"
- args = {
- "ticket": parse_string(request, "ticket", required=True),
- "service": self.cas_service_url,
- }
- try:
- body = yield self._http_client.get_raw(uri, args)
- except PartialDownloadError as pde:
- # Twisted raises this error if the connection is closed,
- # even if that's being used old-http style to signal end-of-data
- body = pde.response
- result = yield self.handle_cas_response(request, body, client_redirect_url)
- return result
+ self._cas_handler = hs.get_cas_handler()
- def handle_cas_response(self, request, cas_response_body, client_redirect_url):
- user, attributes = self.parse_cas_response(cas_response_body)
+ async def on_GET(self, request: SynapseRequest) -> None:
+ client_redirect_url = parse_string(request, "redirectUrl")
+ ticket = parse_string(request, "ticket", required=True)
- for required_attribute, required_value in self.cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+ # Maybe get a session ID (if this ticket is from user interactive
+ # authentication).
+ session = parse_string(request, "session")
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+ # Either client_redirect_url or session must be provided.
+ if not client_redirect_url and not session:
+ message = "Missing string query parameter redirectUrl or session"
+ raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
- return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url
+ await self._cas_handler.handle_ticket(
+ request, ticket, client_redirect_url, session
)
- def parse_cas_response(self, cas_response_body):
- user = None
- attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.error("Error parsing CAS response", exc_info=1)
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
- )
- return user, attributes
-
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -455,74 +443,26 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
- def get_sso_url(self, client_redirect_url):
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
-class SSOAuthHandler(object):
- """
- Utility class for Resources and Servlets which handle the response from a SSO
- service
+class OIDCRedirectServlet(BaseSSORedirectServlet):
+ """Implementation for /login/sso/redirect for the OIDC login flow."""
- Args:
- hs (synapse.server.HomeServer)
- """
+ PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
- self._hostname = hs.hostname
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
- self._macaroon_gen = hs.get_macaroon_generator()
-
- @defer.inlineCallbacks
- def on_successful_auth(
- self, username, request, client_redirect_url, user_display_name=None
- ):
- """Called once the user has successfully authenticated with the SSO.
-
- Registers the user if necessary, and then returns a redirect (with
- a login token) to the client.
-
- Args:
- username (unicode|bytes): the remote user id. We'll map this onto
- something sane for a MXID localpath.
-
- request (SynapseRequest): the incoming request from the browser. We'll
- respond to it with a redirect.
-
- client_redirect_url (unicode): the redirect_url the client gave us when
- it first started the process.
-
- user_display_name (unicode|None): if set, and we have to register a new user,
- we will set their displayname to this.
+ self._oidc_handler = hs.get_oidc_handler()
- Returns:
- Deferred[none]: Completes once we have handled the request.
- """
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = yield self._auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id = yield self._registration_handler.register_user(
- localpart=localpart, default_display_name=user_display_name
- )
-
- login_token = self._macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url
)
- redirect_url = self._add_login_token_to_redirect_url(
- client_redirect_url, login_token
- )
- request.redirect(redirect_url)
- finish_request(request)
-
- @staticmethod
- def _add_login_token_to_redirect_url(url, token):
- url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"loginToken": token})
- url_parts[4] = urllib.parse.urlencode(query)
- return urllib.parse.urlunparse(url_parts)
def register_servlets(hs, http_server):
@@ -532,3 +472,5 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
+ elif hs.config.oidc_enabled:
+ OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 4785a34d75..b0c30b65be 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet):
def on_OPTIONS(self, request):
return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
- # the acccess token wasn't associated with a device.
+ # The access token wasn't associated with a device.
# Just delete the access token
access_token = self.auth.get_access_token_from_request(request)
- yield self._auth_handler.delete_access_token(access_token)
+ await self._auth_handler.delete_access_token(access_token)
else:
- yield self._device_handler.delete_device(
+ await self._device_handler.delete_device(
requester.user.to_string(), requester.device_id
)
@@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet):
def on_OPTIONS(self, request):
return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
# first delete all of the user's devices
- yield self._device_handler.delete_all_devices_for_user(user_id)
+ await self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with
# devices.
- yield self._auth_handler.delete_access_tokens_for_user(user_id)
+ await self._auth_handler.delete_access_tokens_for_user(user_id)
return 200, {}
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 0153525cef..eec16f8ad8 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -19,8 +19,6 @@ import logging
from six import string_types
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet):
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
- allowed = yield self.presence_handler.is_visible(
+ allowed = await self.presence_handler.is_visible(
observed_user=user, observer_user=requester.user
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
- state = yield self.presence_handler.get_state(target_user=user)
+ state = await self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec())
return 200, state
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
@@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet):
raise SynapseError(400, "Unable to parse state")
if self.hs.config.use_presence:
- yield self.presence_handler.set_state(user, state)
+ await self.presence_handler.set_state(user, state)
return 200, {}
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index bbce2e2b71..e7fe50ed72 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,8 +14,8 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
-from twisted.internet import defer
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
@@ -30,19 +30,18 @@ class ProfileDisplaynameRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- displayname = yield self.profile_handler.get_displayname(user)
+ displayname = await self.profile_handler.get_displayname(user)
ret = {}
if displayname is not None:
@@ -50,11 +49,10 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
@@ -63,7 +61,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
except Exception:
return 400, "Unable to parse name"
- yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
return 200, {}
@@ -80,19 +78,18 @@ class ProfileAvatarURLRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- avatar_url = yield self.profile_handler.get_avatar_url(user)
+ avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {}
if avatar_url is not None:
@@ -100,19 +97,22 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
try:
- new_name = content["avatar_url"]
- except Exception:
- return 400, "Unable to parse name"
+ new_avatar_url = content["avatar_url"]
+ except KeyError:
+ raise SynapseError(
+ 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM
+ )
- yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
+ await self.profile_handler.set_avatar_url(
+ user, requester, new_avatar_url, is_admin
+ )
return 200, {}
@@ -129,20 +129,19 @@ class ProfileRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- displayname = yield self.profile_handler.get_displayname(user)
- avatar_url = yield self.profile_handler.get_avatar_url(user)
+ displayname = await self.profile_handler.get_displayname(user)
+ avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {}
if displayname is not None:
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9f8c3d09e3..9fd4908136 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.api.errors import (
NotFoundError,
@@ -46,18 +45,17 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
- @defer.inlineCallbacks
- def on_PUT(self, request, path):
+ async def on_PUT(self, request, path):
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
- spec = _rule_spec_from_path([x for x in path.split("/")])
+ spec = _rule_spec_from_path(path.split("/"))
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
raise SynapseError(400, "rule_id may not contain slashes")
@@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet):
user_id = requester.user.to_string()
if "attr" in spec:
- yield self.set_rule_attr(user_id, spec, content)
+ await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
return 200, {}
@@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet):
after = _namespaced_rule_id(spec, after)
try:
- yield self.store.add_push_rule(
+ await self.store.add_push_rule(
user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
@@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, path):
+ async def on_DELETE(self, request, path):
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
- spec = _rule_spec_from_path([x for x in path.split("/")])
+ spec = _rule_spec_from_path(path.split("/"))
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
- yield self.store.delete_push_rule(user_id, namespaced_rule_id)
+ await self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
return 200, {}
except StoreError as e:
@@ -130,19 +127,18 @@ class PushRuleRestServlet(RestServlet):
else:
raise
- @defer.inlineCallbacks
- def on_GET(self, request, path):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, path):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
- rules = yield self.store.get_push_rules_for_user(user_id)
+ rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rules)
- path = [x for x in path.split("/")][1:]
+ path = path.split("/")[1:]
if path == []:
# we're a reference impl: pedantry is our job.
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 41660682d9..550a2f1b44 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import (
@@ -30,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
+ALLOWED_KEYS = {
+ "app_display_name",
+ "app_id",
+ "data",
+ "device_display_name",
+ "kind",
+ "lang",
+ "profile_tag",
+ "pushkey",
+}
+
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -39,30 +48,17 @@ class PushersRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
- pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
-
- allowed_keys = [
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
- ]
+ pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- for p in pushers:
- for k, v in list(p.items()):
- if k not in allowed_keys:
- del p[k]
+ filtered_pushers = [
+ {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
+ ]
- return 200, {"pushers": pushers}
+ return 200, {"pushers": filtered_pushers}
def on_OPTIONS(self, _):
return 200, {}
@@ -78,9 +74,8 @@ class PushersSetRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
content = parse_json_object_from_request(request)
@@ -91,7 +86,7 @@ class PushersSetRestServlet(RestServlet):
and "kind" in content
and content["kind"] is None
):
- yield self.pusher_pool.remove_pusher(
+ await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string()
)
return 200, {}
@@ -117,14 +112,14 @@ class PushersSetRestServlet(RestServlet):
append = content["append"]
if not append:
- yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+ await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"],
pushkey=content["pushkey"],
not_user_id=user.to_string(),
)
try:
- yield self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
kind=content["kind"],
@@ -164,16 +159,15 @@ class PushersRemoveRestServlet(RestServlet):
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
try:
- yield self.pusher_pool.remove_pusher(
+ await self.pusher_pool.remove_pusher(
app_id=app_id, pushkey=pushkey, user_id=user.to_string()
)
except StoreError as se:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index a6a7b3b57e..105e0cf4d2 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -16,17 +16,18 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
+import re
+from typing import List, Optional
from six.moves.urllib import parse as urlparse
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
+ HttpResponseException,
InvalidClientCredentialsError,
SynapseError,
)
@@ -39,12 +40,17 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+MYPY = False
+if MYPY:
+ import synapse.server
+
logger = logging.getLogger(__name__)
@@ -81,13 +87,13 @@ class RoomCreateRestServlet(TransactionRestServlet):
)
def on_PUT(self, request, txn_id):
+ set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
- info = yield self._room_creation_handler.create_room(
+ info, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -152,15 +158,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT_no_state_key(self, request, room_id, event_type):
return self.on_PUT(request, room_id, event_type, "")
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_type, state_key):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_type, state_key):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
)
msg_handler = self.message_handler
- data = yield msg_handler.get_room_data(
+ data = await msg_handler.get_room_data(
user_id=requester.user.to_string(),
room_id=room_id,
event_type=event_type,
@@ -177,9 +182,11 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
+
+ if txn_id:
+ set_tag("txn_id", txn_id)
content = parse_json_object_from_request(request)
@@ -195,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.room_member_handler.update_membership(
+ event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -203,13 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
+ event_id = event.event_id
- ret = {}
- if event:
- ret = {"event_id": event.event_id}
+ ret = {} # type: dict
+ if event_id:
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
@@ -225,9 +237,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_type, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_id, event_type, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict = {
@@ -240,16 +251,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
+ set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
def on_PUT(self, request, room_id, event_type, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
)
@@ -267,9 +281,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_identifier, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_identifier, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
content = parse_json_object_from_request(request)
@@ -283,20 +296,20 @@ class JoinRoomAliasServlet(TransactionRestServlet):
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
- ]
+ ] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -310,6 +323,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
return 200, {"room_id": room_id}
def on_PUT(self, request, room_identifier, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
@@ -324,12 +339,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
server = parse_string(request, "server", default=None)
try:
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ await self.auth.get_user_by_req(request, allow_guest=True)
except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
@@ -350,26 +364,32 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
handler = self.hs.get_room_list_handler()
- if server:
- data = yield handler.get_remote_public_room_list(
- server, limit=limit, since_token=since_token
- )
+ if server and server != self.hs.config.server_name:
+ try:
+ data = await handler.get_remote_public_room_list(
+ server, limit=limit, since_token=since_token
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
else:
- data = yield handler.get_local_public_room_list(
+ data = await handler.get_local_public_room_list(
limit=limit, since_token=since_token
)
return 200, data
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
- limit = int(content.get("limit", 100))
+ limit = int(content.get("limit", 100)) # type: Optional[int]
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@@ -387,18 +407,25 @@ class PublicRoomListRestServlet(TransactionRestServlet):
else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
handler = self.hs.get_room_list_handler()
- if server:
- data = yield handler.get_remote_public_room_list(
- server,
- limit=limit,
- since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
+ if server and server != self.hs.config.server_name:
+ try:
+ data = await handler.get_remote_public_room_list(
+ server,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
else:
- data = yield handler.get_local_public_room_list(
+ data = await handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
@@ -417,10 +444,9 @@ class RoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
+ async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
@@ -440,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership")
- events = yield handler.get_state_events(
+ events = await handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
at_token=at_token,
@@ -469,11 +495,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- users_with_profile = yield self.message_handler.get_joined_members(
+ users_with_profile = await self.message_handler.get_joined_members(
requester, room_id
)
@@ -489,20 +514,24 @@ class RoomMessageListRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
- event_filter = Filter(json.loads(filter_json))
- if event_filter.filter_json.get("event_format", "client") == "federation":
+ event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ if (
+ event_filter
+ and event_filter.filter_json.get("event_format", "client")
+ == "federation"
+ ):
as_client_event = False
else:
event_filter = None
- msgs = yield self.pagination_handler.get_messages(
+
+ msgs = await self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@@ -522,11 +551,10 @@ class RoomStateRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
- events = yield self.message_handler.get_state_events(
+ events = await self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
@@ -543,11 +571,10 @@ class RoomInitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
- content = yield self.initial_sync_handler.room_initial_sync(
+ content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
return 200, content
@@ -565,11 +592,10 @@ class RoomEventServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
- event = yield self.event_handler.get_event(
+ event = await self.event_handler.get_event(
requester.user, room_id, event_id
)
except AuthError:
@@ -580,7 +606,7 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event = yield self._event_serializer.serialize_event(event, time_now)
+ event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -598,9 +624,8 @@ class RoomEventContextServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -608,11 +633,11 @@ class RoomEventContextServlet(RestServlet):
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes)
- event_filter = Filter(json.loads(filter_json))
+ event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None
- results = yield self.room_context_handler.get_event_context(
+ results = await self.room_context_handler.get_event_context(
requester.user, room_id, event_id, limit, event_filter
)
@@ -620,16 +645,16 @@ class RoomEventContextServlet(RestServlet):
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- results["events_before"] = yield self._event_serializer.serialize_events(
+ results["events_before"] = await self._event_serializer.serialize_events(
results["events_before"], time_now
)
- results["event"] = yield self._event_serializer.serialize_event(
+ results["event"] = await self._event_serializer.serialize_event(
results["event"], time_now
)
- results["events_after"] = yield self._event_serializer.serialize_events(
+ results["events_after"] = await self._event_serializer.serialize_events(
results["events_after"], time_now
)
- results["state"] = yield self._event_serializer.serialize_events(
+ results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
)
@@ -646,15 +671,16 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ async def on_POST(self, request, room_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
- yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
+ await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
def on_PUT(self, request, room_id, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
)
@@ -675,9 +701,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, membership_action, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
Membership.JOIN,
@@ -693,7 +718,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.room_member_handler.do_3pid_invite(
+ await self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -711,10 +736,10 @@ class RoomMembershipRestServlet(TransactionRestServlet):
target = UserID.from_string(content["user_id"])
event_content = None
- if "reason" in content and membership_action in ["kick", "ban"]:
+ if "reason" in content:
event_content = {"reason": content["reason"]}
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -738,6 +763,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return True
def on_PUT(self, request, room_id, membership_action, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
)
@@ -754,12 +781,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, event_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -771,9 +797,12 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
txn_id=txn_id,
)
+ set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
)
@@ -790,35 +819,57 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]:
- yield self.typing_handler.started_typing(
+ await self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=timeout,
)
else:
- yield self.typing_handler.stopped_typing(
+ await self.typing_handler.stopped_typing(
target_user=target_user, auth_user=requester.user, room_id=room_id
)
return 200, {}
+class RoomAliasListServlet(RestServlet):
+ PATTERNS = [
+ re.compile(
+ r"^/_matrix/client/unstable/org\.matrix\.msc2432"
+ r"/rooms/(?P<room_id>[^/]*)/aliases"
+ ),
+ ]
+
+ def __init__(self, hs: "synapse.server.HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.directory_handler = hs.get_handlers().directory_handler
+
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+
+ alias_list = await self.directory_handler.get_aliases_for_room(
+ requester, room_id
+ )
+
+ return 200, {"aliases": alias_list}
+
+
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
@@ -827,14 +878,13 @@ class SearchRestServlet(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch")
- results = yield self.handlers.search_handler.search(
+ results = await self.handlers.search_handler.search(
requester.user, content, batch
)
@@ -849,11 +899,10 @@ class JoinedRoomsRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
+ room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
@@ -909,6 +958,7 @@ def register_servlets(hs, http_server):
JoinedRoomsRestServlet(hs).register(http_server)
RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
+ RoomAliasListServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 2afdbb89e5..747d46eac2 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -17,8 +17,6 @@ import base64
import hashlib
import hmac
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests
)
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 8250ae0ae1..bc11b4dda4 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -32,7 +32,7 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
Args:
path_regex (str): The regex string to match. This should NOT have a ^
- as this will be prefixed.
+ as this will be prefixed.
Returns:
SRE_Pattern
"""
@@ -78,7 +78,7 @@ def interactive_auth_handler(orig):
"""
def wrapped(*args, **kwargs):
- res = defer.maybeDeferred(orig, *args, **kwargs)
+ res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth)
return res
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 2ea515d2f6..d4f721b6b9 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -18,8 +18,6 @@ import logging
from six.moves import http_client
-from twisted.internet import defer
-
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour
@@ -32,6 +30,7 @@ from synapse.http.servlet import (
)
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -67,11 +66,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
template_text=template_text,
)
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"User password resets have been disabled due to lack of email config"
)
raise SynapseError(
@@ -84,6 +82,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Extract params from body
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
email = body["email"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -95,25 +95,23 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
if existing_user_id is None:
+ if self.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- # Have the configured identity server handle the request
- if not self.hs.config.account_threepid_delegate_email:
- logger.warn(
- "No upstream email account_threepid_delegate configured on the server to "
- "handle this request"
- )
- raise SynapseError(
- 400, "Password reset by email is not supported on this homeserver"
- )
+ assert self.hs.config.account_threepid_delegate_email
- ret = yield self.identity_handler.requestEmailToken(
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email,
email,
client_secret,
@@ -122,7 +120,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
else:
# Send password reset emails from Synapse
- sid = yield self.identity_handler.send_threepid_validation(
+ sid = await self.identity_handler.send_threepid_validation(
email,
client_secret,
send_attempt,
@@ -136,71 +134,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
return 200, ret
-class MsisdnPasswordRequestTokenRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
-
- def __init__(self, hs):
- super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
- self.hs = hs
- self.datastore = self.hs.get_datastore()
- self.identity_handler = hs.get_handlers().identity_handler
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- body = parse_json_object_from_request(request)
-
- assert_params_in_dict(
- body, ["client_secret", "country", "phone_number", "send_attempt"]
- )
- client_secret = body["client_secret"]
- country = body["country"]
- phone_number = body["phone_number"]
- send_attempt = body["send_attempt"]
- next_link = body.get("next_link") # Optional param
-
- msisdn = phone_number_to_msisdn(country, phone_number)
-
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
- raise SynapseError(
- 403,
- "Account phone numbers are not authorized on this server",
- Codes.THREEPID_DENIED,
- )
-
- existing_user_id = yield self.datastore.get_user_id_by_threepid(
- "msisdn", msisdn
- )
-
- if existing_user_id is None:
- raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
-
- if not self.hs.config.account_threepid_delegate_msisdn:
- logger.warn(
- "No upstream msisdn account_threepid_delegate configured on the server to "
- "handle this request"
- )
- raise SynapseError(
- 400,
- "Password reset by phone number is not supported on this homeserver",
- )
-
- ret = yield self.identity_handler.requestMsisdnToken(
- self.hs.config.account_threepid_delegate_msisdn,
- country,
- phone_number,
- client_secret,
- send_attempt,
- next_link,
- )
-
- return 200, ret
-
-
class PasswordResetSubmitTokenServlet(RestServlet):
"""Handles 3PID validation token submission"""
PATTERNS = client_patterns(
- "/password_reset/(?P<medium>[^/]*)/submit_token/*$", releases=(), unstable=True
+ "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
def __init__(self, hs):
@@ -214,9 +152,13 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_password_reset_template_failure_html],
+ )
- @defer.inlineCallbacks
- def on_GET(self, request, medium):
+ async def on_GET(self, request, medium):
# We currently only handle threepid token submissions for email
if medium != "email":
raise SynapseError(
@@ -224,7 +166,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Password reset emails have been disabled due to lack of an email config"
)
raise SynapseError(
@@ -232,20 +174,21 @@ class PasswordResetSubmitTokenServlet(RestServlet):
)
sid = parse_string(request, "sid", required=True)
- client_secret = parse_string(request, "client_secret", required=True)
token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
# Attempt to validate a 3PID session
try:
# Mark the session as valid
- next_link = yield self.store.validate_threepid_session(
+ next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec()
)
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
@@ -261,34 +204,12 @@ class PasswordResetSubmitTokenServlet(RestServlet):
request.setResponseCode(e.code)
# Show a failure page with a reason
- html_template, = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_password_reset_template_failure_html],
- )
-
template_vars = {"failure_reason": e.msg}
- html = html_template.render(**template_vars)
+ html = self.failure_email_template.render(**template_vars)
request.write(html.encode("utf-8"))
finish_request(request)
- @defer.inlineCallbacks
- def on_POST(self, request, medium):
- if medium != "email":
- raise SynapseError(
- 400, "This medium is currently not supported for password resets"
- )
-
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["sid", "client_secret", "token"])
-
- valid, _ = yield self.store.validate_threepid_session(
- body["sid"], body["client_secret"], body["token"], self.clock.time_msec()
- )
- response_code = 200 if valid else 400
-
- return response_code, {"success": valid}
-
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
@@ -299,13 +220,27 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
+ self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
+ # we do basic sanity checks here because the auth layer will store these
+ # in sessions. Pull out the new password provided to us.
+ if "new_password" in body:
+ new_password = body.pop("new_password")
+ if not isinstance(new_password, str) or len(new_password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(new_password)
+
+ # If the password is valid, hash it and store it back on the body.
+ # This ensures that only the hashed password is handled everywhere.
+ if "new_password_hash" in body:
+ raise SynapseError(400, "Unexpected property: new_password_hash")
+ body["new_password_hash"] = await self.auth_handler.hash(new_password)
+
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -317,17 +252,23 @@ class PasswordRestServlet(RestServlet):
# In the second case, we require a password to confirm their identity.
if self.auth.has_access_token(request):
- requester = yield self.auth.get_user_by_req(request)
- params = yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ requester = await self.auth.get_user_by_req(request)
+ params = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
)
user_id = requester.user.to_string()
else:
requester = None
- result, params, _ = yield self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
+ result, params, _ = await self.auth_handler.check_auth(
+ [[LoginType.EMAIL_IDENTITY]],
+ request,
body,
self.hs.get_ip_from_request(request),
+ "modify your account password",
)
if LoginType.EMAIL_IDENTITY in result:
@@ -340,7 +281,7 @@ class PasswordRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py)
threepid["address"] = threepid["address"].lower()
# if using email, we must know about the email they're authing with!
- threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+ threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"]
)
if not threepid_user_id:
@@ -350,10 +291,13 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password"])
- new_password = params["new_password"]
+ assert_params_in_dict(params, ["new_password_hash"])
+ new_password_hash = params["new_password_hash"]
+ logout_devices = params.get("logout_devices", True)
- yield self._set_password_handler.set_password(user_id, new_password, requester)
+ await self._set_password_handler.set_password(
+ user_id, new_password_hash, logout_devices, requester
+ )
return 200, {}
@@ -372,8 +316,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -383,19 +326,23 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users
if requester.app_service:
- yield self._deactivate_account_handler.deactivate_account(
+ await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
)
return 200, {}
- yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "deactivate your account",
)
- result = yield self._deactivate_account_handler.deactivate_account(
+ result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
)
if result:
@@ -416,14 +363,37 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ template_html, template_text = load_jinja2_templates(
+ self.config.email_template_dir,
+ [
+ self.config.email_add_threepid_template_html,
+ self.config.email_add_threepid_template_text,
+ ],
+ public_baseurl=self.config.public_baseurl,
+ )
+ self.mailer = Mailer(
+ hs=self.hs,
+ app_name=self.config.email_app_name,
+ template_html=template_html,
+ template_text=template_text,
+ )
+
+ async def on_POST(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
+ raise SynapseError(
+ 400, "Adding an email to your account is disabled on this server"
+ )
+
body = parse_json_object_from_request(request)
- assert_params_in_dict(
- body, ["id_server", "client_secret", "email", "send_attempt"]
- )
- id_server = "https://" + body["id_server"] # Assume https
+ assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
email = body["email"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -435,16 +405,42 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.store.get_user_id_by_threepid(
+ existing_user_id = await self.store.get_user_id_by_threepid(
"email", body["email"]
)
if existing_user_id is not None:
+ if self.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- ret = yield self.identity_handler.requestEmailToken(
- id_server, email, client_secret, send_attempt, next_link
- )
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.account_threepid_delegate_email
+
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.requestEmailToken(
+ self.hs.config.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+ else:
+ # Send threepid validation emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_add_threepid_mail,
+ next_link,
+ )
+
+ # Wrap the session id in a JSON object
+ ret = {"sid": sid}
+
return 200, ret
@@ -457,15 +453,14 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
- body,
- ["id_server", "client_secret", "country", "phone_number", "send_attempt"],
+ body, ["client_secret", "country", "phone_number", "send_attempt"]
)
- id_server = "https://" + body["id_server"] # Assume https
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
@@ -480,17 +475,156 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn)
+ existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
+ if self.hs.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
- ret = yield self.identity_handler.requestMsisdnToken(
- id_server, country, phone_number, client_secret, send_attempt, next_link
+ if not self.hs.config.account_threepid_delegate_msisdn:
+ logger.warning(
+ "No upstream msisdn account_threepid_delegate configured on the server to "
+ "handle this request"
+ )
+ raise SynapseError(
+ 400,
+ "Adding phone numbers to user account is not supported by this homeserver",
+ )
+
+ ret = await self.identity_handler.requestMsisdnToken(
+ self.hs.config.account_threepid_delegate_msisdn,
+ country,
+ phone_number,
+ client_secret,
+ send_attempt,
+ next_link,
)
+
return 200, ret
+class AddThreepidEmailSubmitTokenServlet(RestServlet):
+ """Handles 3PID validation token submission for adding an email to a user's account"""
+
+ PATTERNS = client_patterns(
+ "/add_threepid/email/submit_token$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.config = hs.config
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_add_threepid_template_failure_html],
+ )
+
+ async def on_GET(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
+ raise SynapseError(
+ 400, "Adding an email to your account is disabled on this server"
+ )
+ elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ raise SynapseError(
+ 400,
+ "This homeserver is not validating threepids. Use an identity server "
+ "instead.",
+ )
+
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
+
+ # Attempt to validate a 3PID session
+ try:
+ # Mark the session as valid
+ next_link = await self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
+ )
+
+ # Perform a 302 redirect if next_link is set
+ if next_link:
+ if next_link.startswith("file:///"):
+ logger.warning(
+ "Not redirecting to next_link as it is a local file: address"
+ )
+ else:
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
+
+ # Otherwise show the success template
+ html = self.config.email_add_threepid_template_success_html_content
+ request.setResponseCode(200)
+ except ThreepidValidationError as e:
+ request.setResponseCode(e.code)
+
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html = self.failure_email_template.render(**template_vars)
+
+ request.write(html.encode("utf-8"))
+ finish_request(request)
+
+
+class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
+ """Handles 3PID validation token submission for adding a phone number to a user's
+ account
+ """
+
+ PATTERNS = client_patterns(
+ "/add_threepid/msisdn/submit_token$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.config = hs.config
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ async def on_POST(self, request):
+ if not self.config.account_threepid_delegate_msisdn:
+ raise SynapseError(
+ 400,
+ "This homeserver is not validating phone numbers. Use an identity server "
+ "instead.",
+ )
+
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ["client_secret", "sid", "token"])
+ assert_valid_client_secret(body["client_secret"])
+
+ # Proxy submit_token request to msisdn threepid delegate
+ response = await self.identity_handler.proxy_msisdn_submit_token(
+ self.config.account_threepid_delegate_msisdn,
+ body["client_secret"],
+ body["sid"],
+ body["token"],
+ )
+ return 200, response
+
+
class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
@@ -502,16 +636,21 @@ class ThreepidRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
- threepids = yield self.datastore.user_get_threepids(requester.user.to_string())
+ threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids}
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
+
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
@@ -519,34 +658,111 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(
400, "Missing param three_pid_creds", Codes.MISSING_PARAM
)
+ assert_params_in_dict(threepid_creds, ["client_secret", "sid"])
+
+ sid = threepid_creds["sid"]
+ client_secret = threepid_creds["client_secret"]
+ assert_valid_client_secret(client_secret)
+
+ validation_session = await self.identity_handler.validate_threepid_session(
+ client_secret, sid
+ )
+ if validation_session:
+ await self.auth_handler.add_threepid(
+ user_id,
+ validation_session["medium"],
+ validation_session["address"],
+ validation_session["validated_at"],
+ )
+ return 200, {}
+
+ raise SynapseError(
+ 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
+ )
+
+
+class ThreepidAddRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True)
+
+ def __init__(self, hs):
+ super(ThreepidAddRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+
+ @interactive_auth_handler
+ async def on_POST(self, request):
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
- # Specify None as the identity server to retrieve it from the request body instead
- threepid = yield self.identity_handler.threepid_from_creds(None, threepid_creds)
+ assert_params_in_dict(body, ["client_secret", "sid"])
+ sid = body["sid"]
+ client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
- if not threepid:
- raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED)
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "add a third-party identifier to your account",
+ )
- for reqd in ["medium", "address", "validated_at"]:
- if reqd not in threepid:
- logger.warn("Couldn't add 3pid: invalid response from ID server")
- raise SynapseError(500, "Invalid response from ID Server")
+ validation_session = await self.identity_handler.validate_threepid_session(
+ client_secret, sid
+ )
+ if validation_session:
+ await self.auth_handler.add_threepid(
+ user_id,
+ validation_session["medium"],
+ validation_session["address"],
+ validation_session["validated_at"],
+ )
+ return 200, {}
- yield self.auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+ raise SynapseError(
+ 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
)
- if "bind" in body and body["bind"]:
- logger.debug("Binding threepid %s to %s", threepid, user_id)
- yield self.identity_handler.bind_threepid(threepid_creds, user_id)
+
+class ThreepidBindRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/bind$", releases=(), unstable=True)
+
+ def __init__(self, hs):
+ super(ThreepidBindRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
+ id_server = body["id_server"]
+ sid = body["sid"]
+ id_access_token = body.get("id_access_token") # optional
+ client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ await self.identity_handler.bind_threepid(
+ client_secret, sid, user_id, id_server, id_access_token
+ )
return 200, {}
class ThreepidUnbindRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/3pid/unbind$")
+ PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True)
def __init__(self, hs):
super(ThreepidUnbindRestServlet, self).__init__()
@@ -555,12 +771,11 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
@@ -570,7 +785,7 @@ class ThreepidUnbindRestServlet(RestServlet):
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
- result = yield self.identity_handler.try_unbind_threepid(
+ result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server},
)
@@ -582,19 +797,24 @@ class ThreepidDeleteRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
+ self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
+ if not self.hs.config.enable_3pid_changes:
+ raise SynapseError(
+ 400, "3PID changes are disabled on this server", Codes.FORBIDDEN
+ )
+
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
- ret = yield self.auth_handler.delete_threepid(
+ ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server")
)
except Exception:
@@ -619,22 +839,24 @@ class WhoamiRestServlet(RestServlet):
super(WhoamiRestServlet, self).__init__()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
return 200, {"user_id": requester.user.to_string()}
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
- MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
+ AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
+ AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
+ ThreepidAddRestServlet(hs).register(http_server)
+ ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index f0db204ffa..c1d4cd0caf 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -40,16 +38,19 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
+ self._is_worker = hs.config.worker_app is not None
+
+ async def on_PUT(self, request, user_id, account_data_type):
+ if self._is_worker:
+ raise Exception("Cannot handle PUT /account_data on worker")
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
- max_id = yield self.store.add_account_data_for_user(
+ max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
@@ -57,13 +58,12 @@ class AccountDataServlet(RestServlet):
return 200, {}
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
- event = yield self.store.get_global_account_data_by_type_for_user(
+ event = await self.store.get_global_account_data_by_type_for_user(
account_data_type, user_id
)
@@ -90,10 +90,13 @@ class RoomAccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
+ self._is_worker = hs.config.worker_app is not None
+
+ async def on_PUT(self, request, user_id, room_id, account_data_type):
+ if self._is_worker:
+ raise Exception("Cannot handle PUT /account_data on worker")
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, room_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -106,7 +109,7 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
- max_id = yield self.store.add_account_data_to_room(
+ max_id = await self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
@@ -114,13 +117,12 @@ class RoomAccountDataServlet(RestServlet):
return 200, {}
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, room_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, room_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
- event = yield self.store.get_account_data_for_room_and_type(
+ event = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type
)
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 33f6a23028..2f10fa64e2 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
@@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet):
self.success_html = hs.config.account_validity.account_renewed_html_content
self.failure_html = hs.config.account_validity.invalid_token_html_content
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- token_valid = yield self.account_activity_handler.renew_account(
+ token_valid = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
@@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(response),))
request.write(response.encode("utf8"))
finish_request(request)
- defer.returnValue(None)
class AccountValiditySendMailServlet(RestServlet):
@@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet):
self.auth = hs.get_auth()
self.account_validity = self.hs.config.account_validity
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled:
raise AuthError(
403, "Account renewal via email is disabled on this server."
)
- requester = yield self.auth.get_user_by_req(request, allow_expired=True)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
- yield self.account_activity_handler.send_renewal_email_to_user(user_id)
+ await self.account_activity_handler.send_renewal_email_to_user(user_id)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index f21aff39e5..75590ebaeb 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -132,7 +130,22 @@ class AuthRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- def on_GET(self, request, stagetype):
+ # SSO configuration.
+ self._cas_enabled = hs.config.cas_enabled
+ if self._cas_enabled:
+ self._cas_handler = hs.get_cas_handler()
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
+ self._saml_enabled = hs.config.saml2_enabled
+ if self._saml_enabled:
+ self._saml_handler = hs.get_saml_handler()
+ self._oidc_enabled = hs.config.oidc_enabled
+ if self._oidc_enabled:
+ self._oidc_handler = hs.get_oidc_handler()
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
+
+ async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@@ -144,14 +157,6 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
"session": session,
@@ -160,19 +165,51 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
+
+ elif stagetype == LoginType.SSO:
+ # Display a confirmation page which prompts the user to
+ # re-authenticate with their SSO provider.
+ if self._cas_enabled:
+ # Generate a request to CAS that redirects back to an endpoint
+ # to verify the successful authentication.
+ sso_redirect_url = self._cas_handler.get_redirect_url(
+ {"session": session},
+ )
+
+ elif self._saml_enabled:
+ # Some SAML identity providers (e.g. Google) require a
+ # RelayState parameter on requests. It is not necessary here, so
+ # pass in a dummy redirect URL (which will never get used).
+ client_redirect_url = b"unused"
+ sso_redirect_url = self._saml_handler.handle_redirect_request(
+ client_redirect_url, session
+ )
+
+ elif self._oidc_enabled:
+ client_redirect_url = b""
+ sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url, session
+ )
+
+ else:
+ raise SynapseError(400, "Homeserver not configured for SSO.")
+
+ html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+
else:
raise SynapseError(404, "Unknown auth stage type")
- @defer.inlineCallbacks
- def on_POST(self, request, stagetype):
+ # Render the HTML and return.
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ finish_request(request)
+ return None
+
+ async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
if not session:
@@ -186,7 +223,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
- success = yield self.auth_handler.add_oob_auth(
+ success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
)
@@ -199,23 +236,10 @@ class AuthRestServlet(RestServlet):
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
"sitekey": self.hs.config.recaptcha_public_key,
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
-
- return None
elif stagetype == LoginType.TERMS:
- if ("session" not in request.args or len(request.args["session"])) == 0:
- raise SynapseError(400, "No session supplied")
-
- session = request.args["session"][0]
authdict = {"session": session}
- success = yield self.auth_handler.add_oob_auth(
+ success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
)
@@ -232,17 +256,22 @@ class AuthRestServlet(RestServlet):
"myurl": "%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
}
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
- return None
+ elif stagetype == LoginType.SSO:
+ # The SSO fallback workflow should not post here,
+ raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
else:
raise SynapseError(404, "Unknown auth stage type")
+ # Render the HTML and return.
+ html_bytes = html.encode("utf8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ finish_request(request)
+ return None
+
def on_OPTIONS(self, _):
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index acd58af193..fe9d019c44 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
@@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- user = yield self.store.get_user_by_id(requester.user.to_string())
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user = await self.store.get_user_by_id(requester.user.to_string())
change_password = bool(user["password_hash"])
response = {
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 26d0235208..c0714fcfb1 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api import errors
from synapse.http.servlet import (
RestServlet,
@@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- devices = yield self.device_handler.get_devices_by_user(
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
)
return 200, {"devices": devices}
@@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
try:
body = parse_json_object_from_request(request)
@@ -84,11 +80,15 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
- yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "remove device(s) from your account",
)
- yield self.device_handler.delete_devices(
+ await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"]
)
return 200, {}
@@ -108,18 +108,16 @@ class DeviceRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- device = yield self.device_handler.get_device(
+ async def on_GET(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ device = await self.device_handler.get_device(
requester.user.to_string(), device_id
)
return 200, device
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_DELETE(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request)
try:
body = parse_json_object_from_request(request)
@@ -132,19 +130,22 @@ class DeviceRestServlet(RestServlet):
else:
raise
- yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "remove a device from your account",
)
- yield self.device_handler.delete_device(requester.user.to_string(), device_id)
+ await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}
- @defer.inlineCallbacks
- def on_PUT(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_PUT(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request)
- yield self.device_handler.update_device(
+ await self.device_handler.update_device(
requester.user.to_string(), device_id, body
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index c6ddf24c8d..b28da017cd 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -15,9 +15,7 @@
import logging
-from twisted.internet import defer
-
-from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, filter_id):
+ async def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users")
@@ -52,13 +49,15 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
- filter = yield self.filtering.get_user_filter(
+ filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
)
+ except StoreError as e:
+ if e.code != 404:
+ raise
+ raise NotFoundError("No such filter")
- return 200, filter.get_filter_json()
- except (KeyError, StoreError):
- raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
+ return 200, filter_collection.get_filter_json()
class CreateFilterRestServlet(RestServlet):
@@ -70,11 +69,10 @@ class CreateFilterRestServlet(RestServlet):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- @defer.inlineCallbacks
- def on_POST(self, request, user_id):
+ async def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users")
@@ -85,7 +83,7 @@ class CreateFilterRestServlet(RestServlet):
content = parse_json_object_from_request(request)
set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
- filter_id = yield self.filtering.add_user_filter(
+ filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content
)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 999a0fa80c..d84a6d7e11 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
@@ -38,24 +36,22 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- group_description = yield self.groups_handler.get_group_profile(
+ group_description = await self.groups_handler.get_group_profile(
group_id, requester_user_id
)
return 200, group_description
- @defer.inlineCallbacks
- def on_POST(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- yield self.groups_handler.update_group_profile(
+ await self.groups_handler.update_group_profile(
group_id, requester_user_id, content
)
@@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- get_group_summary = yield self.groups_handler.get_group_summary(
+ get_group_summary = await self.groups_handler.get_group_summary(
group_id, requester_user_id
)
@@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, category_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, category_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_summary_room(
+ resp = await self.groups_handler.update_group_summary_room(
group_id,
requester_user_id,
room_id=room_id,
@@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, category_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, category_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_summary_room(
+ resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
@@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_category(
+ category = await self.groups_handler.get_group_category(
group_id, requester_user_id, category_id=category_id
)
return 200, category
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_category(
+ resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content
)
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_category(
+ resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id
)
@@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_categories(
+ category = await self.groups_handler.get_group_categories(
group_id, requester_user_id
)
@@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_role(
+ category = await self.groups_handler.get_group_role(
group_id, requester_user_id, role_id=role_id
)
return 200, category
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_role(
+ resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content
)
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_role(
+ resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id
)
@@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_roles(
+ category = await self.groups_handler.get_group_roles(
group_id, requester_user_id
)
@@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, role_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, role_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_summary_user(
+ resp = await self.groups_handler.update_group_summary_user(
group_id,
requester_user_id,
user_id=user_id,
@@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, role_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, role_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_summary_user(
+ resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
@@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_rooms_in_group(
+ result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_users_in_group(
+ result = await self.groups_handler.get_users_in_group(
group_id, requester_user_id
)
@@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_invited_users_in_group(
+ result = await self.groups_handler.get_invited_users_in_group(
group_id, requester_user_id
)
@@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.set_group_join_policy(
+ result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content
)
@@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
# TODO: Create group on remote server
@@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
- result = yield self.groups_handler.create_group(
+ result = await self.groups_handler.create_group(
group_id, requester_user_id, content
)
@@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.add_room_to_group(
+ result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
return 200, result
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.remove_room_from_group(
+ result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
@@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, room_id, config_key):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, room_id, config_key):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.update_room_in_group(
+ result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
@@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
- result = yield self.groups_handler.invite(
+ result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config
)
@@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.remove_user_from_group(
+ result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
@@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.remove_user_from_group(
+ result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content
)
@@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.join_group(
+ result = await self.groups_handler.join_group(
group_id, requester_user_id, content
)
@@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.accept_invite(
+ result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content
)
@@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
publicise = content["publicise"]
- yield self.store.update_group_publicity(group_id, requester_user_id, publicise)
+ await self.store.update_group_publicity(group_id, requester_user_id, publicise)
return 200, {}
@@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, user_id):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- result = yield self.groups_handler.get_publicised_groups_for_user(user_id)
+ result = await self.groups_handler.get_publicised_groups_for_user(user_id)
return 200, result
@@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
user_ids = content["user_ids"]
- result = yield self.groups_handler.bulk_get_publicised_groups(user_ids)
+ result = await self.groups_handler.bulk_get_publicised_groups(user_ids)
return 200, result
@@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_joined_groups(requester_user_id)
+ result = await self.groups_handler.get_joined_groups(requester_user_id)
return 200, result
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 2e680134a0..24bb090822 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -27,7 +26,7 @@ from synapse.http.servlet import (
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.types import StreamToken
-from ._base import client_patterns
+from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -43,7 +42,7 @@ class KeyUploadServlet(RestServlet):
"device_id": "<device_id>",
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [
- "m.olm.curve25519-aes-sha256",
+ "m.olm.curve25519-aes-sha2",
]
"keys": {
"<algorithm>:<device_id>": "<key_base64>",
@@ -70,9 +69,8 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@trace(opname="upload_keys")
- @defer.inlineCallbacks
- def on_POST(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -102,7 +100,7 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
- result = yield self.e2e_keys_handler.upload_keys_for_user(
+ result = await self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
)
return 200, result
@@ -126,7 +124,7 @@ class KeyQueryServlet(RestServlet):
"device_id": "<device_id>", // Duplicated to be signed
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [ // List of supported algorithms
- "m.olm.curve25519-aes-sha256",
+ "m.olm.curve25519-aes-sha2",
],
"keys": { // Must include a ed25519 signing key
"<algorithm>:<key_id>": "<key_base64>",
@@ -153,12 +151,12 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.query_devices(body, timeout)
+ result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
return 200, result
@@ -183,9 +181,8 @@ class KeyChangesServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from")
set_tag("from", from_token_string)
@@ -198,7 +195,7 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string()
- results = yield self.device_handler.get_user_ids_changed(user_id, from_token)
+ results = await self.device_handler.get_user_ids_changed(user_id, from_token)
return 200, results
@@ -229,12 +226,100 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+ result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+ return 200, result
+
+
+class SigningKeyUploadServlet(RestServlet):
+ """
+ POST /keys/device_signing/upload HTTP/1.1
+ Content-Type: application/json
+
+ {
+ }
+ """
+
+ PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(SigningKeyUploadServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ self.auth_handler = hs.get_auth_handler()
+
+ @interactive_auth_handler
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "add a device signing key to your account",
+ )
+
+ result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
+ return 200, result
+
+
+class SignaturesUploadServlet(RestServlet):
+ """
+ POST /keys/signatures/upload HTTP/1.1
+ Content-Type: application/json
+
+ {
+ "@alice:example.com": {
+ "<device_id>": {
+ "user_id": "<user_id>",
+ "device_id": "<device_id>",
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ "m.megolm.v1.aes-sha2"
+ ],
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures": {
+ "<signing_user_id>": {
+ "<algorithm>:<signing_key_base64>": "<signature_base64>>"
+ }
+ }
+ }
+ }
+ }
+ """
+
+ PATTERNS = client_patterns("/keys/signatures/upload$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(SignaturesUploadServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ result = await self.e2e_keys_handler.upload_signatures_for_device_keys(
+ user_id, body
+ )
return 200, result
@@ -243,3 +328,5 @@ def register_servlets(hs, http_server):
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
+ SigningKeyUploadServlet(hs).register(http_server)
+ SignaturesUploadServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 10c1ad5b07..aa911d75ee 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet):
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False)
@@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet):
limit = min(limit, 500)
- push_actions = yield self.store.get_push_actions_for_user(
+ push_actions = await self.store.get_push_actions_for_user(
user_id, from_token, limit, only_highlight=(only == "highlight")
)
- receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
+ receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read"
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
- notif_events = yield self.store.get_events(notif_event_ids)
+ notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = []
@@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet):
"actions": pa["actions"],
"ts": pa["received_ts"],
"event": (
- yield self._event_serializer.serialize_event(
+ await self._event_serializer.serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index b4925c0f59..6ae9a5a8e9 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
@@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
- @defer.inlineCallbacks
- def on_POST(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
@@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet):
token = random_string(24)
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
- yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
+ await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
return (
200,
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
new file mode 100644
index 0000000000..968403cca4
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyServlet(RestServlet):
+ PATTERNS = client_patterns("/password_policy$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(PasswordPolicyServlet, self).__init__()
+
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ def on_GET(self, request):
+ if not self.enabled or not self.policy:
+ return (200, {})
+
+ policy = {}
+
+ for param in [
+ "minimum_length",
+ "require_digit",
+ "require_symbol",
+ "require_lowercase",
+ "require_uppercase",
+ ]:
+ if param in self.policy:
+ policy["m.%s" % param] = self.policy[param]
+
+ return (200, policy)
+
+
+def register_servlets(hs, http_server):
+ PasswordPolicyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index b3bf8567e1..67cbc37312 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns
@@ -34,17 +32,16 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None)
if read_event_id:
- yield self.receipts_handler.received_client_receipt(
+ await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
user_id=requester.user.to_string(),
@@ -53,7 +50,7 @@ class ReadMarkerRestServlet(RestServlet):
read_marker_event_id = body.get("m.fully_read", None)
if read_marker_event_id:
- yield self.read_marker_handler.received_client_read_marker(
+ await self.read_marker_handler.received_client_read_marker(
room_id,
user_id=requester.user.to_string(),
event_id=read_marker_event_id,
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 0dab03d227..92555bd4a9 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
@@ -39,16 +37,15 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, receipt_type, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, receipt_type, event_id):
+ requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
- yield self.receipts_handler.received_client_receipt(
+ await self.receipts_handler.received_client_receipt(
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5c7a5f3579..b9ffe86b2a 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -16,24 +16,28 @@
import hmac
import logging
+from typing import List, Union
from six import string_types
-from twisted.internet import defer
-
import synapse
+import synapse.api.auth
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
- LimitExceededError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.config import ConfigError
+from synapse.config.captcha import CaptchaConfig
+from synapse.config.consent_config import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
+from synapse.handlers.auth import AuthHandler
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -44,6 +48,7 @@ from synapse.http.servlet import (
from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -96,11 +101,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=template_text,
)
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"Email registration has been disabled due to lack of email config"
)
raise SynapseError(
@@ -112,6 +116,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# Extract params from body
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
email = body["email"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -123,24 +129,23 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", body["email"]
)
if existing_user_id is not None:
+ if self.hs.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- if not self.hs.config.account_threepid_delegate_email:
- logger.warn(
- "No upstream email account_threepid_delegate configured on the server to "
- "handle this request"
- )
- raise SynapseError(
- 400, "Registration by email is not supported on this homeserver"
- )
+ assert self.hs.config.account_threepid_delegate_email
- ret = yield self.identity_handler.requestEmailToken(
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email,
email,
client_secret,
@@ -149,7 +154,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
else:
# Send registration emails from Synapse
- sid = yield self.identity_handler.send_threepid_validation(
+ sid = await self.identity_handler.send_threepid_validation(
email,
client_secret,
send_attempt,
@@ -175,8 +180,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
@@ -197,17 +201,22 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
if existing_user_id is not None:
+ if self.hs.config.request_token_inhibit_3pid_errors:
+ # Make the client think the operation succeeded. See the rationale in the
+ # comments for request_token_inhibit_3pid_errors.
+ return 200, {"sid": random_string(16)}
+
raise SynapseError(
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)
if not self.hs.config.account_threepid_delegate_msisdn:
- logger.warn(
+ logger.warning(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
)
@@ -215,7 +224,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
400, "Registration by phone number is not supported on this homeserver"
)
- ret = yield self.identity_handler.requestMsisdnToken(
+ ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn,
country,
phone_number,
@@ -246,15 +255,26 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request, medium):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_registration_template_failure_html],
+ )
+
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_registration_template_failure_html],
+ )
+
+ async def on_GET(self, request, medium):
if medium != "email":
raise SynapseError(
400, "This medium is currently not supported for registration"
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warn(
+ logger.warning(
"User registration via email has been disabled due to lack of email config"
)
raise SynapseError(
@@ -268,14 +288,14 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session
try:
# Mark the session as valid
- next_link = yield self.store.validate_threepid_session(
+ next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec()
)
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
@@ -289,17 +309,11 @@ class RegistrationSubmitTokenServlet(RestServlet):
request.setResponseCode(200)
except ThreepidValidationError as e:
- # Show a failure page with a reason
request.setResponseCode(e.code)
# Show a failure page with a reason
- html_template, = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
- )
-
template_vars = {"failure_reason": e.msg}
- html = html_template.render(**template_vars)
+ html = self.failure_email_template.render(**template_vars)
request.write(html.encode("utf-8"))
finish_request(request)
@@ -332,15 +346,19 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
ip = self.hs.get_ip_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred:
- yield wait_deferred
+ await wait_deferred
username = parse_string(request, "username", required=True)
- yield self.registration_handler.check_username(username)
+ await self.registration_handler.check_username(username)
return 200, {"available": True}
@@ -364,50 +382,46 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
+ self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self._registration_flows = _calculate_registration_flows(
+ hs.config, self.auth_handler
+ )
+
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
- time_now = self.clock.time()
-
- allowed, time_allowed = self.ratelimiter.can_do_action(
- client_addr,
- time_now_s=time_now,
- rate_hz=self.hs.config.rc_registration.per_second,
- burst_count=self.hs.config.rc_registration.burst_count,
- update=False,
- )
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now))
- )
+ self.ratelimiter.ratelimit(client_addr, update=False)
kind = b"user"
if b"kind" in request.args:
kind = request.args[b"kind"][0]
if kind == b"guest":
- ret = yield self._do_guest_registration(body, address=client_addr)
+ ret = await self._do_guest_registration(body, address=client_addr)
return ret
elif kind != b"user":
raise UnrecognizedRequestError(
- "Do not understand membership kind: %s" % (kind,)
+ "Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
if "password" in body:
- if (
- not isinstance(body["password"], string_types)
- or len(body["password"]) > 512
- ):
+ password = body.pop("password")
+ if not isinstance(password, string_types) or len(password) > 512:
raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(password)
+
+ # If the password is valid, hash it and store it back on the body.
+ # This ensures that only the hashed password is handled everywhere.
+ if "password_hash" in body:
+ raise SynapseError(400, "Unexpected property: password_hash")
+ body["password_hash"] = await self.auth_handler.hash(password)
desired_username = None
if "username" in body:
@@ -420,7 +434,7 @@ class RegisterRestServlet(RestServlet):
appservice = None
if self.auth.has_access_token(request):
- appservice = yield self.auth.get_appservice_by_req(request)
+ appservice = await self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -440,7 +454,7 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
- result = yield self._do_appservice_registration(
+ result = await self._do_appservice_registration(
desired_username, access_token, body
)
return 200, result # we throw for non 200 responses
@@ -460,12 +474,12 @@ class RegisterRestServlet(RestServlet):
guest_access_token = body.get("guest_access_token", None)
- if "initial_device_display_name" in body and "password" not in body:
+ if "initial_device_display_name" in body and "password_hash" not in body:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
# the original registration params
- logger.warn("Ignoring initial_device_display_name without password")
+ logger.warning("Ignoring initial_device_display_name without password")
del body["initial_device_display_name"]
session_id = self.auth_handler.get_session_id(body)
@@ -475,80 +489,23 @@ class RegisterRestServlet(RestServlet):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
- registered_user_id = self.auth_handler.get_session_data(
+ registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if desired_username is not None:
- yield self.registration_handler.check_username(
+ await self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
)
- # FIXME: need a better error than "no auth flow found" for scenarios
- # where we required 3PID for registration but the user didn't give one
- require_email = "email" in self.hs.config.registrations_require_3pid
- require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid
-
- show_msisdn = True
- if self.hs.config.disable_msisdn_registration:
- show_msisdn = False
- require_msisdn = False
-
- flows = []
- if self.hs.config.enable_registration_captcha:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- # Also add a dummy flow here, otherwise if a client completes
- # recaptcha first we'll assume they were going for this flow
- # and complete the request, when they could have been trying to
- # complete one of the flows with email/msisdn auth.
- flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email:
- flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend(
- [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]
- )
- else:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- flows.extend([[LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email or require_msisdn:
- flows.extend([[LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]])
-
- # Append m.login.terms to all flows if we're requiring consent
- if self.hs.config.user_consent_at_registration:
- new_flows = []
- for flow in flows:
- inserted = False
- # m.login.terms should go near the end but before msisdn or email auth
- for i, stage in enumerate(flow):
- if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
- flow.insert(i, LoginType.TERMS)
- inserted = True
- break
- if not inserted:
- flow.append(LoginType.TERMS)
- flows.extend(new_flows)
-
- auth_result, params, session_id = yield self.auth_handler.check_auth(
- flows, body, self.hs.get_ip_from_request(request)
+ auth_result, params, session_id = await self.auth_handler.check_auth(
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "register a new account",
)
# Check that we're not trying to register a denied 3pid.
@@ -579,11 +536,11 @@ class RegisterRestServlet(RestServlet):
registered = False
else:
# NB: This may be from the auth handler and NOT from the POST
- assert_params_in_dict(params, ["password"])
+ assert_params_in_dict(params, ["password_hash"])
desired_username = params.get("username", None)
guest_access_token = params.get("guest_access_token", None)
- new_password = params.get("password", None)
+ new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -603,7 +560,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- existing_user_id = yield self.store.get_user_id_by_threepid(
+ existing_user_id = await self.store.get_user_id_by_threepid(
medium, address
)
@@ -614,9 +571,9 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
- registered_user_id = yield self.registration_handler.register_user(
+ registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password=new_password,
+ password_hash=new_password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
address=client_addr,
@@ -627,22 +584,22 @@ class RegisterRestServlet(RestServlet):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
- yield self.store.upsert_monthly_active_user(registered_user_id)
+ await self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
- self.auth_handler.set_session_data(
+ await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
registered = True
- return_dict = yield self._create_registration_details(
+ return_dict = await self._create_registration_details(
registered_user_id, params
)
if registered:
- yield self.registration_handler.post_registration_actions(
+ await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
access_token=return_dict.get("access_token"),
@@ -653,15 +610,13 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- @defer.inlineCallbacks
- def _do_appservice_registration(self, username, as_token, body):
- user_id = yield self.registration_handler.appservice_register(
+ async def _do_appservice_registration(self, username, as_token, body):
+ user_id = await self.registration_handler.appservice_register(
username, as_token
)
- return (yield self._create_registration_details(user_id, body))
+ return await self._create_registration_details(user_id, body)
- @defer.inlineCallbacks
- def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -677,18 +632,17 @@ class RegisterRestServlet(RestServlet):
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False
)
result.update({"access_token": access_token, "device_id": device_id})
return result
- @defer.inlineCallbacks
- def _do_guest_registration(self, params, address=None):
+ async def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
- user_id = yield self.registration_handler.register_user(
+ user_id = await self.registration_handler.register_user(
make_guest=True, address=address
)
@@ -696,7 +650,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True
)
@@ -711,6 +665,83 @@ class RegisterRestServlet(RestServlet):
)
+def _calculate_registration_flows(
+ # technically `config` has to provide *all* of these interfaces, not just one
+ config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
+ auth_handler: AuthHandler,
+) -> List[List[str]]:
+ """Get a suitable flows list for registration
+
+ Args:
+ config: server configuration
+ auth_handler: authorization handler
+
+ Returns: a list of supported flows
+ """
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = "email" in config.registrations_require_3pid
+ require_msisdn = "msisdn" in config.registrations_require_3pid
+
+ show_msisdn = True
+ show_email = True
+
+ if config.disable_msisdn_registration:
+ show_msisdn = False
+ require_msisdn = False
+
+ enabled_auth_types = auth_handler.get_enabled_auth_types()
+ if LoginType.EMAIL_IDENTITY not in enabled_auth_types:
+ show_email = False
+ if require_email:
+ raise ConfigError(
+ "Configuration requires email address at registration, but email "
+ "validation is not configured"
+ )
+
+ if LoginType.MSISDN not in enabled_auth_types:
+ show_msisdn = False
+ if require_msisdn:
+ raise ConfigError(
+ "Configuration requires msisdn at registration, but msisdn "
+ "validation is not configured"
+ )
+
+ flows = []
+
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ # Add a dummy step here, otherwise if a client completes
+ # recaptcha first we'll assume they were going for this flow
+ # and complete the request, when they could have been trying to
+ # complete one of the flows with email/msisdn auth.
+ flows.append([LoginType.DUMMY])
+
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if show_email and not require_msisdn:
+ flows.append([LoginType.EMAIL_IDENTITY])
+
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if show_msisdn and not require_email:
+ flows.append([LoginType.MSISDN])
+
+ if show_email and show_msisdn:
+ # always let users provide both MSISDN & email
+ flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
+
+ # Prepend m.login.terms to all flows if we're requiring consent
+ if config.user_consent_at_registration:
+ for flow in flows:
+ flow.insert(0, LoginType.TERMS)
+
+ # Prepend recaptcha to all flows if we're requiring captcha
+ if config.enable_registration_captcha:
+ for flow in flows:
+ flow.insert(0, LoginType.RECAPTCHA)
+
+ return flows
+
+
def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 040b37c504..89002ffbff 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -21,8 +21,6 @@ any time to reflect changes in the MSC.
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet):
request, self.on_PUT_or_POST, request, *args, **kwargs
)
- @defer.inlineCallbacks
- def on_PUT_or_POST(
+ async def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None
):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
# Add relations to a membership is meaningless, so we just deny it
@@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
@@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(
+ self, request, room_id, parent_id, relation_type=None, event_type=None
+ ):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
- room_id, requester.user.to_string()
+ await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester.user.to_string(), allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
- event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from")
@@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet):
if to_token:
to_token = RelationPaginationToken.from_string(to_token)
- pagination_chunk = yield self.store.get_relations_for_event(
+ pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
relation_type=relation_type,
event_type=event_type,
@@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)
- events = yield self.store.get_events_as_list(
+ events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
@@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet):
# We set bundle_aggregations to False when retrieving the original
# event because we want the content before relations were applied to
# it.
- original_event = yield self._event_serializer.serialize_event(
+ original_event = await self._event_serializer.serialize_event(
event, now, bundle_aggregations=False
)
# Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them
# here.
- events = yield self._event_serializer.serialize_events(
+ events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=False
)
@@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet):
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(
+ self, request, room_id, parent_id, relation_type=None, event_type=None
+ ):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
- room_id, requester.user.to_string()
+ await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester.user.to_string(), allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
- event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet):
if to_token:
to_token = AggregationPaginationToken.from_string(to_token)
- pagination_chunk = yield self.store.get_aggregation_groups_for_event(
+ pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
event_type=event_type,
limit=limit,
@@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
- room_id, requester.user.to_string()
+ await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester.user.to_string(), allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token:
to_token = RelationPaginationToken.from_string(to_token)
- result = yield self.store.get_relations_for_event(
+ result = await self.store.get_relations_for_event(
event_id=parent_id,
relation_type=relation_type,
event_type=event_type,
@@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
- events = yield self.store.get_events_as_list(
+ events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
- events = yield self._event_serializer.serialize_events(events, now)
+ events = await self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = events
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e7449864cd..f067b5edac 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -18,8 +18,6 @@ import logging
from six import string_types
from six.moves import http_client
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet):
Codes.BAD_JSON,
)
- yield self.store.add_event_report(
+ await self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index df4f44cd36..59529707df 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, session_id):
+ async def on_PUT(self, request, room_id, session_id):
"""
Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional)
@@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet):
}
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
version = parse_string(request, "version")
@@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet):
if room_id:
body = {"rooms": {room_id: body}}
- yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
- return 200, {}
+ ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
+ return 200, ret
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, session_id):
+ async def on_GET(self, request, room_id, session_id):
"""
Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API.
@@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet):
}
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- version = parse_string(request, "version")
+ version = parse_string(request, "version", required=True)
- room_keys = yield self.e2e_room_keys_handler.get_room_keys(
+ room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
)
@@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet):
return 200, room_keys
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_id, session_id):
+ async def on_DELETE(self, request, room_id, session_id):
"""
Deletes one or more encrypted E2E room keys for a user for backup purposes.
@@ -235,14 +230,14 @@ class RoomKeysServlet(RestServlet):
the version must already have been created via the /change_secret API.
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
- yield self.e2e_room_keys_handler.delete_room_keys(
+ ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
)
- return 200, {}
+ return 200, ret
class RoomKeysNewVersionServlet(RestServlet):
@@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""
Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user
@@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet):
"version": 12345
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
- new_version = yield self.e2e_room_keys_handler.create_version(user_id, info)
+ new_version = await self.e2e_room_keys_handler.create_version(user_id, info)
return 200, {"version": new_version}
# we deliberately don't have a PUT /version, as these things really should
@@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, version):
+ async def on_GET(self, request, version):
"""
Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the
@@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet):
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
try:
- info = yield self.e2e_room_keys_handler.get_version_info(user_id, version)
+ info = await self.e2e_room_keys_handler.get_version_info(user_id, version)
except SynapseError as e:
if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
return 200, info
- @defer.inlineCallbacks
- def on_DELETE(self, request, version):
+ async def on_DELETE(self, request, version):
"""
Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most
@@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet):
if version is None:
raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- yield self.e2e_room_keys_handler.delete_version(user_id, version)
+ await self.e2e_room_keys_handler.delete_version(user_id, version)
return 200, {}
- @defer.inlineCallbacks
- def on_PUT(self, request, version):
+ async def on_PUT(self, request, version):
"""
Update the information about a given version of the user's room_keys backup.
@@ -375,14 +366,14 @@ class RoomKeysVersionServlet(RestServlet):
"ed25519:something": "hijklmnop"
}
},
- "version": "42"
+ "version": "12345"
}
HTTP/1.1 200 OK
Content-Type: application/json
{}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
@@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet):
400, "No version specified to update", Codes.MISSING_PARAM
)
- yield self.e2e_room_keys_handler.update_version(user_id, version, info)
+ await self.e2e_room_keys_handler.update_version(user_id, version, info)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index d2c3316eb7..f357015a70 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
@@ -59,22 +57,22 @@ class RoomUpgradeRestServlet(RestServlet):
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self._auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id):
+ requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",))
new_version = content["new_version"]
- if new_version not in KNOWN_ROOM_VERSIONS:
+ new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
+ if new_version is None:
raise SynapseError(
400,
"Your homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
- new_room_id = yield self._room_creation_handler.upgrade_room(
+ new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d90e52ed1a..db829f3098 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -14,8 +14,7 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import Tuple
from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
@@ -51,19 +50,18 @@ class SendToDeviceRestServlet(servlet.RestServlet):
request, self._put, request, message_type, txn_id
)
- @defer.inlineCallbacks
- def _put(self, request, message_type, txn_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def _put(self, request, message_type, txn_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
sender_user_id = requester.user.to_string()
- yield self.device_message_handler.send_device_message(
+ await self.device_message_handler.send_device_message(
sender_user_id, message_type, content["messages"]
)
- response = (200, {})
+ response = (200, {}) # type: Tuple[int, dict]
return response
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index c98c5a3802..8fa68dd37f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -18,10 +18,8 @@ import logging
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import PresenceState
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
@@ -74,7 +72,7 @@ class SyncRestServlet(RestServlet):
"""
PATTERNS = client_patterns("/sync$")
- ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
+ ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
@@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
@@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet):
400, "'from' is not a valid query parameter. Did you mean 'since'?"
)
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
device_id = requester.device_id
@@ -112,32 +109,44 @@ class SyncRestServlet(RestServlet):
full_state = parse_boolean(request, "full_state", default=False)
logger.debug(
- "/sync: user=%r, timeout=%r, since=%r,"
- " set_presence=%r, filter_id=%r, device_id=%r"
- % (user, timeout, since, set_presence, filter_id, device_id)
+ "/sync: user=%r, timeout=%r, since=%r, "
+ "set_presence=%r, filter_id=%r, device_id=%r",
+ user,
+ timeout,
+ since,
+ set_presence,
+ filter_id,
+ device_id,
)
request_key = (user, timeout, since, filter_id, full_state, device_id)
- if filter_id:
- if filter_id.startswith("{"):
- try:
- filter_object = json.loads(filter_id)
- set_timeline_upper_limit(
- filter_object, self.hs.config.filter_timeline_limit
- )
- except Exception:
- raise SynapseError(400, "Invalid filter JSON")
- self.filtering.check_valid_filter(filter_object)
- filter = FilterCollection(filter_object)
- else:
- filter = yield self.filtering.get_user_filter(user.localpart, filter_id)
+ if filter_id is None:
+ filter_collection = DEFAULT_FILTER_COLLECTION
+ elif filter_id.startswith("{"):
+ try:
+ filter_object = json.loads(filter_id)
+ set_timeline_upper_limit(
+ filter_object, self.hs.config.filter_timeline_limit
+ )
+ except Exception:
+ raise SynapseError(400, "Invalid filter JSON")
+ self.filtering.check_valid_filter(filter_object)
+ filter_collection = FilterCollection(filter_object)
else:
- filter = DEFAULT_FILTER_COLLECTION
+ try:
+ filter_collection = await self.filtering.get_user_filter(
+ user.localpart, filter_id
+ )
+ except StoreError as err:
+ if err.code != 404:
+ raise
+ # fix up the description and errcode to be more useful
+ raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM)
sync_config = SyncConfig(
user=user,
- filter_collection=filter,
+ filter_collection=filter_collection,
is_guest=requester.is_guest,
request_key=request_key,
device_id=device_id,
@@ -149,20 +158,20 @@ class SyncRestServlet(RestServlet):
since_token = None
# send any outstanding server notices to the user.
- yield self._server_notices_sender.on_user_syncing(user.to_string())
+ await self._server_notices_sender.on_user_syncing(user.to_string())
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
- yield self.presence_handler.set_state(
+ await self.presence_handler.set_state(
user, {"presence": set_presence}, True
)
- context = yield self.presence_handler.user_syncing(
+ context = await self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence
)
with context:
- sync_result = yield self.sync_handler.wait_for_sync_for_user(
+ sync_result = await self.sync_handler.wait_for_sync_for_user(
sync_config,
since_token=since_token,
timeout=timeout,
@@ -170,14 +179,13 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
- response_content = yield self.encode_response(
- time_now, sync_result, requester.access_token_id, filter
+ response_content = await self.encode_response(
+ time_now, sync_result, requester.access_token_id, filter_collection
)
return 200, response_content
- @defer.inlineCallbacks
- def encode_response(self, time_now, sync_result, access_token_id, filter):
+ async def encode_response(self, time_now, sync_result, access_token_id, filter):
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == "federation":
@@ -185,7 +193,7 @@ class SyncRestServlet(RestServlet):
else:
raise Exception("Unknown event format %s" % (filter.event_format,))
- joined = yield self.encode_joined(
+ joined = await self.encode_joined(
sync_result.joined,
time_now,
access_token_id,
@@ -193,11 +201,11 @@ class SyncRestServlet(RestServlet):
event_formatter,
)
- invited = yield self.encode_invited(
+ invited = await self.encode_invited(
sync_result.invited, time_now, access_token_id, event_formatter
)
- archived = yield self.encode_archived(
+ archived = await self.encode_archived(
sync_result.archived,
time_now,
access_token_id,
@@ -238,8 +246,9 @@ class SyncRestServlet(RestServlet):
]
}
- @defer.inlineCallbacks
- def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter):
+ async def encode_joined(
+ self, rooms, time_now, token_id, event_fields, event_formatter
+ ):
"""
Encode the joined rooms in a sync result
@@ -260,7 +269,7 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = yield self.encode_room(
+ joined[room.room_id] = await self.encode_room(
room,
time_now,
token_id,
@@ -271,8 +280,7 @@ class SyncRestServlet(RestServlet):
return joined
- @defer.inlineCallbacks
- def encode_invited(self, rooms, time_now, token_id, event_formatter):
+ async def encode_invited(self, rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@@ -292,7 +300,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
- invite = yield self._event_serializer.serialize_event(
+ invite = await self._event_serializer.serialize_event(
room.invite,
time_now,
token_id=token_id,
@@ -307,8 +315,9 @@ class SyncRestServlet(RestServlet):
return invited
- @defer.inlineCallbacks
- def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
+ async def encode_archived(
+ self, rooms, time_now, token_id, event_fields, event_formatter
+ ):
"""
Encode the archived rooms in a sync result
@@ -329,7 +338,7 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = yield self.encode_room(
+ joined[room.room_id] = await self.encode_room(
room,
time_now,
token_id,
@@ -340,8 +349,7 @@ class SyncRestServlet(RestServlet):
return joined
- @defer.inlineCallbacks
- def encode_room(
+ async def encode_room(
self, room, time_now, token_id, joined, only_fields, event_formatter
):
"""
@@ -382,15 +390,15 @@ class SyncRestServlet(RestServlet):
# We've had bug reports that events were coming down under the
# wrong room.
if event.room_id != room.room_id:
- logger.warn(
+ logger.warning(
"Event %r is under room %r instead of %r",
event.event_id,
room.room_id,
event.room_id,
)
- serialized_state = yield serialize(state_events)
- serialized_timeline = yield serialize(timeline_events)
+ serialized_state = await serialize(state_events)
+ serialized_timeline = await serialize(timeline_events)
account_data = room.account_data
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 3b555669a0..a3f12e8a77 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -37,13 +35,12 @@ class TagListServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
return 200, {"tags": tags}
@@ -64,27 +61,25 @@ class TagServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, room_id, tag):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id, room_id, tag):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
body = parse_json_object_from_request(request)
- max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
+ max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, user_id, room_id, tag):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, user_id, room_id, tag):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
- max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
+ max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 2e8d672471..23709960ad 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
@@ -35,11 +33,10 @@ class ThirdPartyProtocolsServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- protocols = yield self.appservice_handler.get_3pe_protocols()
+ protocols = await self.appservice_handler.get_3pe_protocols()
return 200, protocols
@@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- protocols = yield self.appservice_handler.get_3pe_protocols(
+ protocols = await self.appservice_handler.get_3pe_protocols(
only_protocol=protocol
)
if protocol in protocols:
@@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop(b"access_token", None)
- results = yield self.appservice_handler.query_3pe(
+ results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
@@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop(b"access_token", None)
- results = yield self.appservice_handler.query_3pe(
+ results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 2da0f55811..83f3b6b70a 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet
@@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet):
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.")
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 2863affbab..bef91a2d3e 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""Searches for users in directory
Returns:
@@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet):
]
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled:
@@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet):
except Exception:
raise SynapseError(400, "`search_term` is required field")
- results = yield self.user_directory_handler.search_users(
+ results = await self.user_directory_handler.search_users(
user_id, search_term, limit
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0058b6b459..0d668df0b6 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,9 +49,18 @@ class VersionsRestServlet(RestServlet):
"r0.3.0",
"r0.4.0",
"r0.5.0",
+ "r0.6.0",
],
# as per MSC1497:
- "unstable_features": {"m.lazy_load_members": True},
+ "unstable_features": {
+ # Implements support for label-based filtering as described in
+ # MSC2326.
+ "org.matrix.label_based_filtering": True,
+ # Implements support for cross signing as described in MSC1756
+ "org.matrix.e2e_cross_signing": True,
+ # Implements additional endpoints as described in MSC2432
+ "org.matrix.msc2432": True,
+ },
},
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 55580bc59e..ab671f7334 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,12 +13,11 @@
# limitations under the License.
import logging
+from typing import Dict, Set
from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import (
@@ -102,8 +101,8 @@ class RemoteKey(DirectServeResource):
@wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
- server, = request.postpath
- query = {server.decode("ascii"): {}}
+ (server,) = request.postpath
+ query = {server.decode("ascii"): {}} # type: dict
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@@ -124,8 +123,7 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- @defer.inlineCallbacks
- def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ async def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)
store_queries = []
@@ -142,13 +140,13 @@ class RemoteKey(DirectServeResource):
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
- cached = yield self.store.get_server_keys_json(store_queries)
+ cached = await self.store.get_server_keys_json(store_queries)
json_results = set()
time_now_ms = self.clock.time_msec()
- cache_misses = dict()
+ cache_misses = {} # type: Dict[str, Set[str]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
@@ -214,8 +212,8 @@ class RemoteKey(DirectServeResource):
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
- yield self.fetcher.get_keys(cache_misses)
- yield self.query_keys(request, query, query_remote_on_cache_miss=False)
+ await self.fetcher.get_keys(cache_misses)
+ await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:
diff --git a/synapse/rest/media/v0/__init__.py b/synapse/rest/media/v0/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/synapse/rest/media/v0/__init__.py
+++ /dev/null
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
deleted file mode 100644
index 86884c0ef4..0000000000
--- a/synapse/rest/media/v0/content_repository.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import logging
-import os
-import re
-
-from canonicaljson import json
-
-from twisted.protocols.basic import FileSender
-from twisted.web import resource, server
-
-from synapse.api.errors import Codes, cs_error
-from synapse.http.server import finish_request, respond_with_json_bytes
-
-logger = logging.getLogger(__name__)
-
-
-class ContentRepoResource(resource.Resource):
- """Provides file uploading and downloading.
-
- Uploads are POSTed to wherever this Resource is linked to. This resource
- returns a "content token" which can be used to GET this content again. The
- token is typically a path, but it may not be. Tokens can expire, be
- one-time uses, etc.
-
- In this case, the token is a path to the file and contains 3 interesting
- sections:
- - User ID base64d (for namespacing content to each user)
- - random 24 char string
- - Content type base64d (so we can return it when clients GET it)
-
- """
-
- isLeaf = True
-
- def __init__(self, hs, directory):
- resource.Resource.__init__(self)
- self.hs = hs
- self.directory = directory
-
- def render_GET(self, request):
- # no auth here on purpose, to allow anyone to view, even across home
- # servers.
-
- # TODO: A little crude here, we could do this better.
- filename = request.path.decode("ascii").split("/")[-1]
- # be paranoid
- filename = re.sub("[^0-9A-z.-_]", "", filename)
-
- file_path = self.directory + "/" + filename
-
- logger.debug("Searching for %s", file_path)
-
- if os.path.isfile(file_path):
- # filename has the content type
- base64_contentype = filename.split(".")[1]
- content_type = base64.urlsafe_b64decode(base64_contentype)
- logger.info("Sending file %s", file_path)
- f = open(file_path, "rb")
- request.setHeader("Content-Type", content_type)
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our matrix
- # clients are smart enough to be happy with Cache-Control (right?)
- request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
-
- d = FileSender().beginFileTransfer(f, request)
-
- # after the file has been sent, clean up and finish the request
- def cbFinished(ignored):
- f.close()
- finish_request(request)
-
- d.addCallback(cbFinished)
- else:
- respond_with_json_bytes(
- request,
- 404,
- json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)),
- send_cors=True,
- )
-
- return server.NOT_DONE_YET
-
- def render_OPTIONS(self, request):
- respond_with_json_bytes(request, 200, {}, send_cors=True)
- return server.NOT_DONE_YET
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 5fefee4dde..3689777266 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,6 @@
import logging
import os
-from six import PY3
from six.moves import urllib
from twisted.internet import defer
@@ -30,6 +29,22 @@ from synapse.util.stringutils import is_ascii
logger = logging.getLogger(__name__)
+# list all text content types that will have the charset default to UTF-8 when
+# none is given
+TEXT_CONTENT_TYPES = [
+ "text/css",
+ "text/csv",
+ "text/html",
+ "text/calendar",
+ "text/plain",
+ "text/javascript",
+ "application/json",
+ "application/ld+json",
+ "application/rtf",
+ "image/svg+xml",
+ "text/xml",
+]
+
def parse_media_id(request):
try:
@@ -96,7 +111,14 @@ def add_file_headers(request, media_type, file_size, upload_name):
def _quote(x):
return urllib.parse.quote(x.encode("utf-8"))
- request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ # Default to a UTF-8 charset for text content types.
+ # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
+ if media_type.lower() in TEXT_CONTENT_TYPES:
+ content_type = media_type + "; charset=UTF-8"
+ else:
+ content_type = media_type
+
+ request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
if upload_name:
# RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
#
@@ -135,27 +157,25 @@ def add_file_headers(request, media_type, file_size, upload_name):
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
-_FILENAME_SEPARATOR_CHARS = set(
- (
- "(",
- ")",
- "<",
- ">",
- "@",
- ",",
- ";",
- ":",
- "\\",
- '"',
- "/",
- "[",
- "]",
- "?",
- "=",
- "{",
- "}",
- )
-)
+_FILENAME_SEPARATOR_CHARS = {
+ "(",
+ ")",
+ "<",
+ ">",
+ "@",
+ ",",
+ ";",
+ ":",
+ "\\",
+ '"',
+ "/",
+ "[",
+ "]",
+ "?",
+ "=",
+ "{",
+ "}",
+}
def _can_encode_filename_as_token(x):
@@ -195,7 +215,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
- logger.debug("Responding to media request with responder %s")
+ logger.debug("Responding to media request with responder %s", responder)
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
@@ -303,23 +323,15 @@ def get_filename_from_headers(headers):
upload_name_utf8 = upload_name_utf8[7:]
# We have a filename*= section. This MUST be ASCII, and any UTF-8
# bytes are %-quoted.
- if PY3:
- try:
- # Once it is decoded, we can then unquote the %-encoded
- # parts strictly into a unicode string.
- upload_name = urllib.parse.unquote(
- upload_name_utf8.decode("ascii"), errors="strict"
- )
- except UnicodeDecodeError:
- # Incorrect UTF-8.
- pass
- else:
- # On Python 2, we first unquote the %-encoded parts and then
- # decode it strictly using UTF-8.
- try:
- upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
- except UnicodeDecodeError:
- pass
+ try:
+ # Once it is decoded, we can then unquote the %-encoded
+ # parts strictly into a unicode string.
+ upload_name = urllib.parse.unquote(
+ upload_name_utf8.decode("ascii"), errors="strict"
+ )
+ except UnicodeDecodeError:
+ # Incorrect UTF-8.
+ pass
# If there isn't check for an ascii name.
if not upload_name:
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 66a01559e1..24d3ae5bbc 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -50,6 +50,9 @@ class DownloadResource(DirectServeResource):
b" media-src 'self';"
b" object-src 'self';",
)
+ request.setHeader(
+ b"Referrer-Policy", b"no-referrer",
+ )
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index b972e152a9..fd10d42f2f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,12 +18,12 @@ import errno
import logging
import os
import shutil
+from typing import Dict, Tuple
from six import iteritems
import twisted.internet.error
import twisted.web.http
-from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -113,15 +113,14 @@ class MediaRepository(object):
"update_recently_accessed_media", self._update_recently_accessed
)
- @defer.inlineCallbacks
- def _update_recently_accessed(self):
+ async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
- yield self.store.update_cached_last_access_time(
+ await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)
@@ -137,8 +136,7 @@ class MediaRepository(object):
else:
self.recently_accessed_locals.add(media_id)
- @defer.inlineCallbacks
- def create_content(
+ async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
@@ -157,11 +155,11 @@ class MediaRepository(object):
file_info = FileInfo(server_name=None, file_id=media_id)
- fname = yield self.media_storage.store_file(content, file_info)
+ fname = await self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
- yield self.store.store_local_media(
+ await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -170,12 +168,11 @@ class MediaRepository(object):
user_id=auth_user,
)
- yield self._generate_thumbnails(None, media_id, media_id, media_type)
+ await self._generate_thumbnails(None, media_id, media_id, media_type)
return "mxc://%s/%s" % (self.server_name, media_id)
- @defer.inlineCallbacks
- def get_local_media(self, request, media_id, name):
+ async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
@@ -189,7 +186,7 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
@@ -203,13 +200,12 @@ class MediaRepository(object):
file_info = FileInfo(None, media_id, url_cache=url_cache)
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
- @defer.inlineCallbacks
- def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
@@ -235,8 +231,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (yield self.remote_media_linearizer.queue(key)):
- responder, media_info = yield self._get_remote_media_impl(
+ with (await self.remote_media_linearizer.queue(key)):
+ responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -245,14 +241,13 @@ class MediaRepository(object):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
- yield respond_with_responder(
+ await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
- @defer.inlineCallbacks
- def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -273,8 +268,8 @@ class MediaRepository(object):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
- with (yield self.remote_media_linearizer.queue(key)):
- responder, media_info = yield self._get_remote_media_impl(
+ with (await self.remote_media_linearizer.queue(key)):
+ responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)
@@ -285,8 +280,7 @@ class MediaRepository(object):
return media_info
- @defer.inlineCallbacks
- def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -298,7 +292,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
- media_info = yield self.store.get_cached_remote_media(server_name, media_id)
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
@@ -316,19 +310,18 @@ class MediaRepository(object):
logger.info("Media is quarantined")
raise NotFoundError()
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info
# Failed to find the file anywhere, lets download it.
- media_info = yield self._download_remote_file(server_name, media_id, file_id)
+ media_info = await self._download_remote_file(server_name, media_id, file_id)
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- @defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -350,7 +343,7 @@ class MediaRepository(object):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
- length, headers = yield self.client.get_file(
+ length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
@@ -363,7 +356,7 @@ class MediaRepository(object):
},
)
except RequestSendFailed as e:
- logger.warn(
+ logger.warning(
"Request failed fetching remote media %s/%s: %r",
server_name,
media_id,
@@ -372,7 +365,7 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
- logger.warn(
+ logger.warning(
"HTTP error fetching remote media %s/%s: %s",
server_name,
media_id,
@@ -383,10 +376,12 @@ class MediaRepository(object):
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
- logger.warn("Failed to fetch remote media %s/%s", server_name, media_id)
+ logger.warning(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
raise
except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
+ logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception(
@@ -394,7 +389,7 @@ class MediaRepository(object):
)
raise SynapseError(502, "Failed to fetch remote media")
- yield finish()
+ await finish()
media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
@@ -402,7 +397,7 @@ class MediaRepository(object):
logger.info("Stored remote media in file %r", fname)
- yield self.store.store_cached_remote_media(
+ await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
@@ -420,7 +415,7 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
+ await self._generate_thumbnails(server_name, media_id, file_id, media_type)
return media_info
@@ -455,16 +450,15 @@ class MediaRepository(object):
return t_byte_source
- @defer.inlineCallbacks
- def generate_local_exact_thumbnail(
+ async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -487,7 +481,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -497,22 +491,21 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
- yield self.store.store_local_thumbnail(
+ await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return output_path
- @defer.inlineCallbacks
- def generate_remote_exact_thumbnail(
+ async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
@@ -534,7 +527,7 @@ class MediaRepository(object):
thumbnail_type=t_type,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -544,7 +537,7 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
- yield self.store.store_remote_media_thumbnail(
+ await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -557,8 +550,7 @@ class MediaRepository(object):
return output_path
- @defer.inlineCallbacks
- def _generate_thumbnails(
+ async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
@@ -579,7 +571,7 @@ class MediaRepository(object):
if not requirements:
return
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
@@ -597,13 +589,13 @@ class MediaRepository(object):
return
if thumbnailer.transpose_method is not None:
- m_width, m_height = yield defer_to_thread(
+ m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
- thumbnails = {}
+ thumbnails = {} # type: Dict[Tuple[int, int, str], str]
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)
@@ -617,11 +609,11 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
- t_byte_source = yield defer_to_thread(
+ t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
@@ -643,7 +635,7 @@ class MediaRepository(object):
url_cache=url_cache,
)
- output_path = yield self.media_storage.store_file(
+ output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
@@ -653,7 +645,7 @@ class MediaRepository(object):
# Write to database
if server_name:
- yield self.store.store_remote_media_thumbnail(
+ await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
@@ -664,15 +656,14 @@ class MediaRepository(object):
t_len,
)
else:
- yield self.store.store_local_thumbnail(
+ await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
- @defer.inlineCallbacks
- def delete_old_remote_media(self, before_ts):
- old_media = yield self.store.get_remote_media_before(before_ts)
+ async def delete_old_remote_media(self, before_ts):
+ old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -686,12 +677,12 @@ class MediaRepository(object):
# TODO: Should we delete from the backup store
- with (yield self.remote_media_linearizer.queue(key)):
+ with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
except OSError as e:
- logger.warn("Failed to remove file: %r", full_path)
+ logger.warning("Failed to remove file: %r", full_path)
if e.errno == errno.ENOENT:
pass
else:
@@ -702,7 +693,7 @@ class MediaRepository(object):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
- yield self.store.delete_remote_media(origin, media_id)
+ await self.store.delete_remote_media(origin, media_id)
deleted += 1
return {"deleted": deleted}
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3b87717a5a..683a79c966 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -148,6 +148,7 @@ class MediaStorage(object):
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
+ logger.debug("Streaming %s from %s", path, provider)
return res
return None
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 7a56cd4b6c..f206605727 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -23,6 +23,7 @@ import re
import shutil
import sys
import traceback
+from typing import Dict, Optional
import six
from six import string_types
@@ -56,6 +57,9 @@ logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+OG_TAG_NAME_MAXLEN = 50
+OG_TAG_VALUE_MAXLEN = 1000
+
class PreviewUrlResource(DirectServeResource):
isLeaf = True
@@ -74,12 +78,15 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
+ self.url_preview_accept_language = hs.config.url_preview_accept_language
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
@@ -117,8 +124,10 @@ class PreviewUrlResource(DirectServeResource):
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug(
- ("Matching attrib '%s' with value '%s' against" " pattern '%s'")
- % (attrib, value, pattern)
+ "Matching attrib '%s' with value '%s' against pattern '%s'",
+ attrib,
+ value,
+ pattern,
)
if value is None:
@@ -134,7 +143,7 @@ class PreviewUrlResource(DirectServeResource):
match = False
continue
if match:
- logger.warn("URL %s blocked by url_blacklist entry %s", url, entry)
+ logger.warning("URL %s blocked by url_blacklist entry %s", url, entry)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
)
@@ -157,8 +166,7 @@ class PreviewUrlResource(DirectServeResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
- @defer.inlineCallbacks
- def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url, user, ts):
"""Check the db, and download the URL and build a preview
Args:
@@ -167,11 +175,11 @@ class PreviewUrlResource(DirectServeResource):
ts (int):
Returns:
- Deferred[str]: json-encoded og data
+ Deferred[bytes]: json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
- cache_result = yield self.store.get_url_cache(url, ts)
+ cache_result = await self.store.get_url_cache(url, ts)
if (
cache_result
and cache_result["expires_ts"] > ts
@@ -184,13 +192,13 @@ class PreviewUrlResource(DirectServeResource):
og = og.encode("utf8")
return og
- media_info = yield self._download_url(url, user)
+ media_info = await self._download_url(url, user)
- logger.debug("got media_info of '%s'" % media_info)
+ logger.debug("got media_info of '%s'", media_info)
if _is_media(media_info["media_type"]):
file_id = media_info["filesystem_id"]
- dims = yield self.media_repo._generate_thumbnails(
+ dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, media_info["media_type"], url_cache=True
)
@@ -206,7 +214,7 @@ class PreviewUrlResource(DirectServeResource):
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % url)
+ logger.warning("Couldn't get dims for %s" % url)
# define our OG response for this media
elif _is_html(media_info["media_type"]):
@@ -230,8 +238,8 @@ class PreviewUrlResource(DirectServeResource):
# If we don't find a match, we'll look at the HTTP Content-Type, and
# if that doesn't exist, we'll fall back to UTF-8.
if not encoding:
- match = _content_type_match.match(media_info["media_type"])
- encoding = match.group(1) if match else "utf-8"
+ content_match = _content_type_match.match(media_info["media_type"])
+ encoding = content_match.group(1) if content_match else "utf-8"
og = decode_and_calc_og(body, media_info["uri"], encoding)
@@ -240,21 +248,21 @@ class PreviewUrlResource(DirectServeResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if "og:image" in og and og["og:image"]:
- image_info = yield self._download_url(
+ image_info = await self._download_url(
_rebase_url(og["og:image"], media_info["uri"]), user
)
if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info["filesystem_id"]
- dims = yield self.media_repo._generate_thumbnails(
+ dims = await self.media_repo._generate_thumbnails(
None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % og["og:image"])
+ logger.warning("Couldn't get dims for %s", og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name,
@@ -265,15 +273,27 @@ class PreviewUrlResource(DirectServeResource):
else:
del og["og:image"]
else:
- logger.warn("Failed to find any OG data in %s", url)
+ logger.warning("Failed to find any OG data in %s", url)
og = {}
- logger.debug("Calculated OG for %s as %s" % (url, og))
+ # filter out any stupidly long values
+ keys_to_remove = []
+ for k, v in og.items():
+ # values can be numeric as well as strings, hence the cast to str
+ if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN:
+ logger.warning(
+ "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN]
+ )
+ keys_to_remove.append(k)
+ for k in keys_to_remove:
+ del og[k]
+
+ logger.debug("Calculated OG for %s as %s", url, og)
- jsonog = json.dumps(og).encode("utf8")
+ jsonog = json.dumps(og)
# store OG in history-aware DB cache
- yield self.store.store_url_cache(
+ await self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
@@ -283,10 +303,9 @@ class PreviewUrlResource(DirectServeResource):
media_info["created_ts"],
)
- return jsonog
+ return jsonog.encode("utf8")
- @defer.inlineCallbacks
- def _download_url(self, url, user):
+ async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -297,9 +316,12 @@ class PreviewUrlResource(DirectServeResource):
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- logger.debug("Trying to get url '%s'" % url)
- length, headers, uri, code = yield self.client.get_file(
- url, output_stream=f, max_size=self.max_spider_size
+ logger.debug("Trying to get preview for url '%s'", url)
+ length, headers, uri, code = await self.client.get_file(
+ url,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
)
except SynapseError:
# Pass SynapseErrors through directly, so that the servlet
@@ -317,7 +339,7 @@ class PreviewUrlResource(DirectServeResource):
)
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
- logger.warn("Error downloading %s: %r", url, e)
+ logger.warning("Error downloading %s: %r", url, e)
raise SynapseError(
500,
@@ -325,7 +347,7 @@ class PreviewUrlResource(DirectServeResource):
% (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
- yield finish()
+ await finish()
try:
if b"Content-Type" in headers:
@@ -336,7 +358,7 @@ class PreviewUrlResource(DirectServeResource):
download_name = get_filename_from_headers(headers)
- yield self.store.store_local_media(
+ await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
@@ -373,22 +395,21 @@ class PreviewUrlResource(DirectServeResource):
"expire_url_cache_data", self._expire_url_cache_data
)
- @defer.inlineCallbacks
- def _expire_url_cache_data(self):
+ async def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
now = self.clock.time_msec()
- logger.info("Running url preview cache expiry")
+ logger.debug("Running url preview cache expiry")
- if not (yield self.store.has_completed_background_updates()):
+ if not (await self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries
- media_ids = yield self.store.get_expired_url_cache(now)
+ media_ids = await self.store.get_expired_url_cache(now)
removed_media = []
for media_id in media_ids:
@@ -398,7 +419,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
@@ -410,17 +431,19 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
- yield self.store.delete_url_cache(removed_media)
+ await self.store.delete_url_cache(removed_media)
if removed_media:
logger.info("Deleted %d entries from url cache", len(removed_media))
+ else:
+ logger.debug("No entries removed from url cache")
# Now we delete old images associated with the url cache.
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
expire_before = now - 2 * 24 * 60 * 60 * 1000
- media_ids = yield self.store.get_url_cache_media_before(expire_before)
+ media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
for media_id in media_ids:
@@ -430,7 +453,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
try:
@@ -446,7 +469,7 @@ class PreviewUrlResource(DirectServeResource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
@@ -458,9 +481,12 @@ class PreviewUrlResource(DirectServeResource):
except Exception:
pass
- yield self.store.delete_url_cache_media(removed_media)
+ await self.store.delete_url_cache_media(removed_media)
- logger.info("Deleted %d media from url cache", len(removed_media))
+ if removed_media:
+ logger.info("Deleted %d media from url cache", len(removed_media))
+ else:
+ logger.debug("No media removed from url cache")
def decode_and_calc_og(body, media_uri, request_encoding=None):
@@ -499,9 +525,13 @@ def _calc_og(tree, media_uri):
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
- og = {}
+ og = {} # type: Dict[str, Optional[str]]
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
+ # if we've got more than 50 tags, someone is taking the piss
+ if len(og) >= 50:
+ logger.warning("Skipping OG for page with too many 'og:' tags")
+ return {}
og[tag.attrib["property"]] = tag.attrib["content"]
# TODO: grab article: meta tags too, e.g.:
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 37687ea7f4..858680be26 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -77,6 +77,9 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote
+ def __str__(self):
+ return "StorageProviderWrapper[%s]" % (self.backend,)
+
def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
return defer.succeed(None)
@@ -114,6 +117,9 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
+ def __str__(self):
+ return "FileStorageProviderBackend[%s]" % (self.base_directory,)
+
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 08329884ac..0b87220234 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
@@ -79,11 +77,10 @@ class ThumbnailResource(DirectServeResource):
)
self.media_repo.mark_recently_accessed(server_name, media_id)
- @defer.inlineCallbacks
- def _respond_local_thumbnail(
+ async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -93,7 +90,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
- thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
@@ -114,14 +111,13 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(request, responder, t_type, t_length)
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
- @defer.inlineCallbacks
- def _select_or_generate_local_thumbnail(
+ async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
@@ -130,7 +126,7 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
- media_info = yield self.store.get_local_media(media_id)
+ media_info = await self.store.get_local_media(media_id)
if not media_info:
respond_404(request)
@@ -140,7 +136,7 @@ class ThumbnailResource(DirectServeResource):
respond_404(request)
return
- thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+ thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
@@ -162,15 +158,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
- yield respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
- file_path = yield self.media_repo.generate_local_exact_thumbnail(
+ file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
@@ -180,13 +176,12 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
- yield respond_with_file(request, desired_type, file_path)
+ await respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
- @defer.inlineCallbacks
- def _select_or_generate_remote_thumbnail(
+ async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
@@ -196,9 +191,9 @@ class ThumbnailResource(DirectServeResource):
desired_method,
desired_type,
):
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+ media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
- thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -224,15 +219,15 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
+ responder = await self.media_storage.fetch_media(file_info)
if responder:
- yield respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
- file_path = yield self.media_repo.generate_remote_exact_thumbnail(
+ file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
@@ -243,21 +238,20 @@ class ThumbnailResource(DirectServeResource):
)
if file_path:
- yield respond_with_file(request, desired_type, file_path)
+ await respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
- @defer.inlineCallbacks
- def _respond_remote_thumbnail(
+ async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
- media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
+ media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
- thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+ thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
@@ -278,8 +272,8 @@ class ThumbnailResource(DirectServeResource):
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- responder = yield self.media_storage.fetch_media(file_info)
- yield respond_with_responder(request, responder, t_type, t_length)
+ responder = await self.media_storage.fetch_media(file_info)
+ await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
@@ -296,8 +290,8 @@ class ThumbnailResource(DirectServeResource):
d_h = desired_height
if desired_method.lower() == "crop":
- info_list = []
- info_list2 = []
+ crop_info_list = []
+ crop_info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
@@ -309,7 +303,7 @@ class ThumbnailResource(DirectServeResource):
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
- info_list.append(
+ crop_info_list.append(
(
aspect_quality,
min_quality,
@@ -320,7 +314,7 @@ class ThumbnailResource(DirectServeResource):
)
)
else:
- info_list2.append(
+ crop_info_list2.append(
(
aspect_quality,
min_quality,
@@ -330,10 +324,10 @@ class ThumbnailResource(DirectServeResource):
info,
)
)
- if info_list:
- return min(info_list)[-1]
+ if crop_info_list:
+ return min(crop_info_list)[-1]
else:
- return min(info_list2)[-1]
+ return min(crop_info_list2)[-1]
else:
info_list = []
info_list2 = []
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index c995d7e043..c234ea7421 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -82,13 +82,21 @@ class Thumbnailer(object):
else:
return (max_height * self.width) // self.height, max_height
+ def _resize(self, width, height):
+ # 1-bit or 8-bit color palette images need converting to RGB
+ # otherwise they will be scaled using nearest neighbour which
+ # looks awful
+ if self.image.mode in ["1", "P"]:
+ self.image = self.image.convert("RGB")
+ return self.image.resize((width, height), Image.ANTIALIAS)
+
def scale(self, width, height, output_type):
"""Rescales the image to the given dimensions.
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
- scaled = self.image.resize((width, height), Image.ANTIALIAS)
+ scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type):
@@ -107,13 +115,13 @@ class Thumbnailer(object):
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
- scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS)
+ scaled_image = self._resize(width, scaled_height)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else:
scaled_width = (height * self.width) // self.height
- scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS)
+ scaled_image = self._resize(scaled_width, height)
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
@@ -121,5 +129,8 @@ class Thumbnailer(object):
def _encode_image(self, output_image, output_type):
output_bytes_io = BytesIO()
- output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
+ fmt = self.FORMATS[output_type]
+ if fmt == "JPEG":
+ output_image = output_image.convert("RGB")
+ output_image.save(output_bytes_io, fmt, quality=80)
return output_bytes_io
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 5d76bbdf68..83d005812d 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -17,7 +17,7 @@ import logging
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
from synapse.http.server import (
DirectServeResource,
respond_with_json,
@@ -56,7 +56,11 @@ class UploadResource(DirectServeResource):
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
- raise SynapseError(msg="Upload request body is too large", code=413)
+ raise SynapseError(
+ msg="Upload request body is too large",
+ code=413,
+ errcode=Codes.TOO_LARGE,
+ )
upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py
new file mode 100644
index 0000000000..d958dd65bb
--- /dev/null
+++ b/synapse/rest/oidc/__init__.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.web.resource import Resource
+
+from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCResource(Resource):
+ def __init__(self, hs):
+ Resource.__init__(self)
+ self.putChild(b"callback", OIDCCallbackResource(hs))
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
new file mode 100644
index 0000000000..c03194f001
--- /dev/null
+++ b/synapse/rest/oidc/callback_resource.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.http.server import DirectServeResource, wrap_html_request_handler
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCCallbackResource(DirectServeResource):
+ isLeaf = 1
+
+ def __init__(self, hs):
+ super().__init__()
+ self._oidc_handler = hs.get_oidc_handler()
+
+ @wrap_html_request_handler
+ async def _async_render_GET(self, request):
+ return await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 69ecc5e4b4..75e58043b4 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.python import failure
-from synapse.http.server import DirectServeResource, wrap_html_request_handler
+from synapse.api.errors import SynapseError
+from synapse.http.server import DirectServeResource, return_html_error
class SAML2ResponseResource(DirectServeResource):
@@ -25,7 +27,21 @@ class SAML2ResponseResource(DirectServeResource):
def __init__(self, hs):
super().__init__()
self._saml_handler = hs.get_saml_handler()
+ self._error_html_template = hs.config.saml2.saml2_error_html_template
+
+ async def _async_render_GET(self, request):
+ # We're not expecting any GET request on that resource if everything goes right,
+ # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
+ # In this case, just tell the user that something went wrong and they should
+ # try to authenticate again.
+ f = failure.Failure(
+ SynapseError(400, "Unexpected GET request on /saml2/authn_response")
+ )
+ return_html_error(f, request, self._error_html_template)
- @wrap_html_request_handler
async def _async_render_POST(self, request):
- return await self._saml_handler.handle_saml_response(request)
+ try:
+ await self._saml_handler.handle_saml_response(request)
+ except Exception:
+ f = failure.Failure()
+ return_html_error(f, request, self._error_html_template)
|