diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index bf0f9bd077..64d5c58b65 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -26,8 +27,9 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -113,7 +115,7 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -141,10 +143,10 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- def on_OPTIONS(self, request):
+ def on_OPTIONS(self, request: SynapseRequest):
return 200, {}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
@@ -153,9 +155,9 @@ class LoginRestServlet(RestServlet):
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
- result = await self.do_jwt_login(login_submission)
+ result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = await self.do_token_login(login_submission)
+ result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError:
@@ -166,14 +168,14 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- async def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
- dict: HTTP response
+ HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -206,11 +208,14 @@ class LoginRestServlet(RestServlet):
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- address = address.lower()
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
@@ -288,25 +293,30 @@ class LoginRestServlet(RestServlet):
return result
async def _complete_login(
- self, user_id, login_submission, callback=None, create_non_existent_users=False
- ):
+ self,
+ user_id: str,
+ login_submission: JsonDict,
+ callback: Optional[
+ Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
+ ] = None,
+ create_non_existent_users: bool = False,
+ ) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
- all succesful logins.
+ all successful logins.
- Applies the ratelimiting for succesful login attempts against an
+ Applies the ratelimiting for successful login attempts against an
account.
Args:
- user_id (str): ID of the user to register.
- login_submission (dict): Dictionary of login information.
- callback (func|None): Callback function to run after registration.
- create_non_existent_users (bool): Whether to create the user if
- they don't exist. Defaults to False.
+ user_id: ID of the user to register.
+ login_submission: Dictionary of login information.
+ callback: Callback function to run after registration.
+ create_non_existent_users: Whether to create the user if they don't
+ exist. Defaults to False.
Returns:
- result (Dict[str,str]): Dictionary of account information after
- successful registration.
+ result: Dictionary of account information after successful registration.
"""
# Before we actually log them in we check if they've already logged in
@@ -340,7 +350,7 @@ class LoginRestServlet(RestServlet):
return result
- async def do_token_login(self, login_submission):
+ async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -350,7 +360,7 @@ class LoginRestServlet(RestServlet):
result = await self._complete_login(user_id, login_submission)
return result
- async def do_jwt_login(self, login_submission):
+ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..165313b572 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,6 +14,7 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -28,6 +29,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -63,11 +65,27 @@ class ProfileDisplaynameRestServlet(RestServlet):
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -76,6 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
@@ -114,11 +133,27 @@ class ProfileAvatarURLRestServlet(RestServlet):
user, requester, new_avatar_url, is_admin
)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
+
return 200, {}
def on_OPTIONS(self, request, user_id):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 46811abbfa..01b80b86fa 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -217,10 +217,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
event_id = event.event_id
- ret = {} # type: dict
- if event_id:
- set_tag("event_id", event_id)
- ret = {"event_id": event_id}
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
@@ -725,7 +723,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ new_room=False,
+ id_access_token=content.get("id_access_token"),
)
return 200, {}
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 747d46eac2..50277c6cf6 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet):
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
- password = base64.b64encode(mac.digest())
+ password = base64.b64encode(mac.digest()).decode("ascii")
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
|