summary refs log tree commit diff
path: root/synapse/rest/client/v1
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1')
-rw-r--r--synapse/rest/client/v1/login.py104
-rw-r--r--synapse/rest/client/v1/presence.py8
-rw-r--r--synapse/rest/client/v1/pusher.py10
-rw-r--r--synapse/rest/client/v1/room.py31
-rw-r--r--synapse/rest/client/v1/voip.py2
5 files changed, 95 insertions, 60 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py

index dceb2792fa..379f668d6f 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter @@ -26,8 +27,9 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder -from synapse.types import UserID +from synapse.types import JsonDict, UserID from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) @@ -60,10 +62,18 @@ def login_id_thirdparty_from_phone(identifier): Returns: Login identifier dict of type 'm.id.threepid' """ - if "country" not in identifier or "number" not in identifier: + if "country" not in identifier or ( + # The specification requires a "phone" field, while Synapse used to require a "number" + # field. Accept both for backwards compatibility. + "phone" not in identifier + and "number" not in identifier + ): raise SynapseError(400, "Invalid phone-type identifier") - msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) + # Accept both "phone" and "number" as valid keys in m.id.phone + phone_number = identifier.get("phone", identifier["number"]) + + msisdn = phone_number_to_msisdn(identifier["country"], phone_number) return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} @@ -73,17 +83,25 @@ class LoginRestServlet(RestServlet): CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" - JWT_TYPE = "m.login.jwt" + JWT_TYPE = "org.matrix.login.jwt" + JWT_TYPE_DEPRECATED = "m.login.jwt" def __init__(self, hs): super(LoginRestServlet, self).__init__() self.hs = hs + + # JWT configuration variables. self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm + self.jwt_issuer = hs.config.jwt_issuer + self.jwt_audiences = hs.config.jwt_audiences + + # SSO configuration. self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -104,10 +122,11 @@ class LoginRestServlet(RestServlet): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) - def on_GET(self, request): + def on_GET(self, request: SynapseRequest): flows = [] if self.jwt_enabled: flows.append({"type": LoginRestServlet.JWT_TYPE}) + flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED}) if self.cas_enabled: # we advertise CAS for backwards compat, though MSC1721 renamed it @@ -131,20 +150,21 @@ class LoginRestServlet(RestServlet): return 200, {"flows": flows} - def on_OPTIONS(self, request): + def on_OPTIONS(self, request: SynapseRequest): return 200, {} - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest): self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) try: if self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE + or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - result = await self.do_jwt_login(login_submission) + result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - result = await self.do_token_login(login_submission) + result = await self._do_token_login(login_submission) else: result = await self._do_other_login(login_submission) except KeyError: @@ -155,14 +175,14 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - async def _do_other_login(self, login_submission): + async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins Args: login_submission: Returns: - dict: HTTP response + HTTP response """ # Log the request we got, but only certain fields to minimise the chance of # logging someone's password (even if they accidentally put it in the wrong @@ -195,11 +215,14 @@ class LoginRestServlet(RestServlet): if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) if medium == "email": - # For emails, transform the address to lowercase. - # We store all email addreses as lowercase in the DB. - # (See add_threepid in synapse/handlers/auth.py) - address = address.lower() + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. @@ -277,25 +300,30 @@ class LoginRestServlet(RestServlet): return result async def _complete_login( - self, user_id, login_submission, callback=None, create_non_existent_users=False - ): + self, + user_id: str, + login_submission: JsonDict, + callback: Optional[ + Callable[[Dict[str, str]], Awaitable[Dict[str, str]]] + ] = None, + create_non_existent_users: bool = False, + ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on - all succesful logins. + all successful logins. - Applies the ratelimiting for succesful login attempts against an + Applies the ratelimiting for successful login attempts against an account. Args: - user_id (str): ID of the user to register. - login_submission (dict): Dictionary of login information. - callback (func|None): Callback function to run after registration. - create_non_existent_users (bool): Whether to create the user if - they don't exist. Defaults to False. + user_id: ID of the user to register. + login_submission: Dictionary of login information. + callback: Callback function to run after registration. + create_non_existent_users: Whether to create the user if they don't + exist. Defaults to False. Returns: - result (Dict[str,str]): Dictionary of account information after - successful registration. + result: Dictionary of account information after successful registration. """ # Before we actually log them in we check if they've already logged in @@ -329,7 +357,7 @@ class LoginRestServlet(RestServlet): return result - async def do_token_login(self, login_submission): + async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission["token"] auth_handler = self.auth_handler user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( @@ -339,28 +367,32 @@ class LoginRestServlet(RestServlet): result = await self._complete_login(user_id, login_submission) return result - async def do_jwt_login(self, login_submission): + async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED + 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN ) import jwt - from jwt.exceptions import InvalidTokenError try: payload = jwt.decode( - token, self.jwt_secret, algorithms=[self.jwt_algorithm] + token, + self.jwt_secret, + algorithms=[self.jwt_algorithm], + issuer=self.jwt_issuer, + audience=self.jwt_audiences, + ) + except jwt.PyJWTError as e: + # A JWT error occurred, return some info back to the client. + raise LoginError( + 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN, ) - except jwt.ExpiredSignatureError: - raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) - except InvalidTokenError: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) user = payload.get("sub", None) if user is None: - raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) + raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN) user_id = UserID(user, self.hs.hostname).to_string() result = await self._complete_login( diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eec16f8ad8..970fdd5834 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py
@@ -17,8 +17,6 @@ """ import logging -from six import string_types - from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -51,7 +49,9 @@ class PresenceStatusRestServlet(RestServlet): raise AuthError(403, "You are not allowed to see their presence.") state = await self.presence_handler.get_state(target_user=user) - state = format_user_presence_state(state, self.clock.time_msec()) + state = format_user_presence_state( + state, self.clock.time_msec(), include_user_id=False + ) return 200, state @@ -71,7 +71,7 @@ class PresenceStatusRestServlet(RestServlet): if "status_msg" in content: state["status_msg"] = content.pop("status_msg") - if not isinstance(state["status_msg"], string_types): + if not isinstance(state["status_msg"], str): raise SynapseError(400, "status_msg must be a string.") if content: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 550a2f1b44..5f65cb7d83 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py
@@ -16,7 +16,7 @@ import logging from synapse.api.errors import Codes, StoreError, SynapseError -from synapse.http.server import finish_request +from synapse.http.server import respond_with_html_bytes from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -177,13 +177,9 @@ class PushersRemoveRestServlet(RestServlet): self.notifier.on_new_replication_data() - request.setResponseCode(200) - request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader( - b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),) + respond_with_html_bytes( + request, 200, PushersRemoveRestServlet.SUCCESS_HTML, ) - request.write(PushersRemoveRestServlet.SUCCESS_HTML) - finish_request(request) return None def on_OPTIONS(self, _): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 43b64608e7..1a3398316d 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -15,11 +15,11 @@ # limitations under the License. """ This module contains REST servlets to do with rooms: /rooms/<paths> """ + import logging import re from typing import List, Optional - -from six.moves.urllib import parse as urlparse +from urllib import parse as urlparse from canonicaljson import json @@ -218,10 +218,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) event_id = event.event_id - ret = {} # type: dict - if event_id: - set_tag("event_id", event_id) - ret = {"event_id": event_id} + set_tag("event_id", event_id) + ret = {"event_id": event_id} return 200, ret @@ -518,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args - filter_bytes = parse_string(request, b"filter", encoding=None) - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] if ( event_filter @@ -630,9 +628,9 @@ class RoomEventContextServlet(RestServlet): limit = parse_integer(request, "limit", default=10) # picking the API shape for symmetry with /messages - filter_bytes = parse_string(request, "filter") - if filter_bytes: - filter_json = urlparse.unquote(filter_bytes) + filter_str = parse_string(request, b"filter", encoding="utf-8") + if filter_str: + filter_json = urlparse.unquote(filter_str) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] else: event_filter = None @@ -820,9 +818,18 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() + # If we're not on the typing writer instance we should scream if we get + # requests. + self._is_typing_writer = ( + hs.config.worker.writers.typing == hs.get_instance_name() + ) + async def on_PUT(self, request, room_id, user_id): requester = await self.auth.get_user_by_req(request) + if not self._is_typing_writer: + raise Exception("Got /typing request on instance that is not typing writer") + room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 747d46eac2..50277c6cf6 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py
@@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet): # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the # same result as the TURN server. - password = base64.b64encode(mac.digest()) + password = base64.b64encode(mac.digest()).decode("ascii") elif turnUris and turnUsername and turnPassword and userLifetime: username = turnUsername