diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 46e458e95b..87f927890c 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
+ shared_rooms,
sync,
tags,
thirdparty,
@@ -125,3 +126,6 @@ class ClientRestResource(JsonResource):
synapse.rest.admin.register_servlets_for_client_rest_resource(
hs, client_resource
)
+
+ # unstable
+ shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 9eda592de9..1c88c93f38 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -35,8 +35,10 @@ from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
+ DeleteRoomRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
+ RoomMembersRestServlet,
RoomRestServlet,
ShutdownRoomRestServlet,
)
@@ -200,6 +202,8 @@ def register_servlets(hs, http_server):
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
+ RoomMembersRestServlet(hs).register(http_server)
+ DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 8173baef8f..09726d52d6 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from http import HTTPStatus
from typing import List, Optional
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -30,9 +31,8 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
historical_admin_path_patterns,
)
-from synapse.storage.data_stores.main.room import RoomSortOrder
+from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
-from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -46,20 +46,10 @@ class ShutdownRoomRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
- DEFAULT_MESSAGE = (
- "Sharing illegal content on this server is not permitted and rooms in"
- " violation will be blocked."
- )
-
def __init__(self, hs):
self.hs = hs
- self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
- self._room_creation_handler = hs.get_room_creation_handler()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- self._replication = hs.get_replication_data_handler()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@@ -67,116 +57,74 @@ class ShutdownRoomRestServlet(RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"])
- new_room_user_id = content["new_room_user_id"]
-
- room_creator_requester = create_requester(new_room_user_id)
- message = content.get("message", self.DEFAULT_MESSAGE)
- room_name = content.get("room_name", "Content Violation Notification")
-
- info, stream_id = await self._room_creation_handler.create_room(
- room_creator_requester,
- config={
- "preset": "public_chat",
- "name": room_name,
- "power_level_content_override": {"users_default": -10},
- },
- ratelimit=False,
+ ret = await self.room_shutdown_handler.shutdown_room(
+ room_id=room_id,
+ new_room_user_id=content["new_room_user_id"],
+ new_room_name=content.get("room_name"),
+ message=content.get("message"),
+ requester_user_id=requester.user.to_string(),
+ block=True,
)
- new_room_id = info["room_id"]
- requester_user_id = requester.user.to_string()
+ return (200, ret)
- logger.info(
- "Shutting down room %r, joining to new room: %r", room_id, new_room_id
- )
- # This will work even if the room is already blocked, but that is
- # desirable in case the first attempt at blocking the room failed below.
- await self.store.block_room(room_id, requester_user_id)
-
- # We now wait for the create room to come back in via replication so
- # that we can assume that all the joins/invites have propogated before
- # we try and auto join below.
- #
- # TODO: Currently the events stream is written to from master
- await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
- )
-
- users = await self.state.get_current_users_in_room(room_id)
- kicked_users = []
- failed_to_kick_users = []
- for user_id in users:
- if not self.hs.is_mine_id(user_id):
- continue
+class DeleteRoomRestServlet(RestServlet):
+ """Delete a room from server. It is a combination and improvement of
+ shut down and purge room.
+ Shuts down a room by removing all local users from the room.
+ Blocking all future invites and joins to the room is optional.
+ If desired any local aliases will be repointed to a new room
+ created by `new_room_user_id` and kicked users will be auto
+ joined to the new room.
+ It will remove all trace of a room from the database.
+ """
- logger.info("Kicking %r from %r...", user_id, room_id)
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
- try:
- target_requester = create_requester(user_id)
- _, stream_id = await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=room_id,
- action=Membership.LEAVE,
- content={},
- ratelimit=False,
- require_consent=False,
- )
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
+ self.pagination_handler = hs.get_pagination_handler()
- # Wait for leave to come in over replication before trying to forget.
- await self._replication.wait_for_stream_position(
- self.hs.config.worker.writers.events, "events", stream_id
- )
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
- await self.room_member_handler.forget(target_requester.user, room_id)
+ content = parse_json_object_from_request(request)
- await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=new_room_id,
- action=Membership.JOIN,
- content={},
- ratelimit=False,
- require_consent=False,
- )
+ block = content.get("block", False)
+ if not isinstance(block, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'block' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
- kicked_users.append(user_id)
- except Exception:
- logger.exception(
- "Failed to leave old room and join new room for %r", user_id
- )
- failed_to_kick_users.append(user_id)
-
- await self.event_creation_handler.create_and_send_nonmember_event(
- room_creator_requester,
- {
- "type": "m.room.message",
- "content": {"body": message, "msgtype": "m.text"},
- "room_id": new_room_id,
- "sender": new_room_user_id,
- },
- ratelimit=False,
- )
+ purge = content.get("purge", True)
+ if not isinstance(purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
- aliases_for_room = await maybe_awaitable(
- self.store.get_aliases_for_room(room_id)
+ ret = await self.room_shutdown_handler.shutdown_room(
+ room_id=room_id,
+ new_room_user_id=content.get("new_room_user_id"),
+ new_room_name=content.get("room_name"),
+ message=content.get("message"),
+ requester_user_id=requester.user.to_string(),
+ block=block,
)
- await self.store.update_aliases_for_room(
- room_id, new_room_id, requester_user_id
- )
+ # Purge room
+ if purge:
+ await self.pagination_handler.purge_room(room_id)
- return (
- 200,
- {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- },
- )
+ return (200, ret)
class ListRoomRestServlet(RestServlet):
@@ -292,6 +240,31 @@ class RoomRestServlet(RestServlet):
return 200, ret
+class RoomMembersRestServlet(RestServlet):
+ """
+ Get members list of a room.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, room_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ ret = await self.store.get_room(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ members = await self.store.get_users_in_room(room_id)
+ ret = {"members": members, "total": len(members)}
+
+ return 200, ret
+
+
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
@@ -343,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet):
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
+ # update_membership with an action of "invite" can raise a
+ # ShadowBanError. This is not handled since it is assumed that
+ # an admin isn't going to call this API with a shadow-banned user.
await self.room_member_handler.update_membership(
requester=requester,
target=fake_requester.user,
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index fefc8f71fa..f3e77da850 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,9 +16,7 @@ import hashlib
import hmac
import logging
import re
-
-from six import text_type
-from six.moves import http_client
+from http import HTTPStatus
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -75,6 +73,7 @@ class UsersRestServletV2(RestServlet):
The parameters `from` and `limit` are required only for pagination.
By default, a `limit` of 100 is used.
The parameter `user_id` can be used to filter by user id.
+ The parameter `name` can be used to filter by user id or display name.
The parameter `guests` can be used to exclude guest users.
The parameter `deactivated` can be used to include deactivated users.
"""
@@ -91,11 +90,12 @@ class UsersRestServletV2(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
user_id = parse_string(request, "user_id", default=None)
+ name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
users, total = await self.store.get_users_paginate(
- start, limit, user_id, guests, deactivated
+ start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
if len(users) >= limit:
@@ -215,10 +215,7 @@ class UserRestServletV2(RestServlet):
await self.store.set_server_admin(target_user, set_admin_to)
if "password" in body:
- if (
- not isinstance(body["password"], text_type)
- or len(body["password"]) > 512
- ):
+ if not isinstance(body["password"], str) or len(body["password"]) > 512:
raise SynapseError(400, "Invalid password")
else:
new_password = body["password"]
@@ -244,6 +241,15 @@ class UserRestServletV2(RestServlet):
await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
)
+ elif not deactivate and user["deactivated"]:
+ if "password" not in body:
+ raise SynapseError(
+ 400, "Must provide a password to re-activate an account."
+ )
+
+ await self.deactivate_account_handler.activate_account(
+ target_user.to_string()
+ )
user = await self.admin_handler.get_user(target_user)
return 200, user
@@ -252,14 +258,13 @@ class UserRestServletV2(RestServlet):
password = body.get("password")
password_hash = None
if password is not None:
- if not isinstance(password, text_type) or len(password) > 512:
+ if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
password_hash = await self.auth_handler.hash(password)
admin = body.get("admin", None)
user_type = body.get("user_type", None)
displayname = body.get("displayname", None)
- threepids = body.get("threepids", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
@@ -370,10 +375,7 @@ class UserRegisterServlet(RestServlet):
400, "username must be specified", errcode=Codes.BAD_JSON
)
else:
- if (
- not isinstance(body["username"], text_type)
- or len(body["username"]) > 512
- ):
+ if not isinstance(body["username"], str) or len(body["username"]) > 512:
raise SynapseError(400, "Invalid username")
username = body["username"].encode("utf-8")
@@ -386,7 +388,7 @@ class UserRegisterServlet(RestServlet):
)
else:
password = body["password"]
- if not isinstance(password, text_type) or len(password) > 512:
+ if not isinstance(password, str) or len(password) > 512:
raise SynapseError(400, "Invalid password")
password_bytes = password.encode("utf-8")
@@ -477,7 +479,7 @@ class DeactivateAccountRestServlet(RestServlet):
erase = body.get("erase", False)
if not isinstance(erase, bool):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'erase' must be a boolean, if given",
Codes.BAD_JSON,
)
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 6da71dc46f..7be5c0fb88 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
-class HttpTransactionCache(object):
+class HttpTransactionCache:
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 5934b1fe8b..b210015173 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
try:
- service = await self.auth.get_appservice_by_req(request)
+ service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index dceb2792fa..a14618ac84 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,9 +14,14 @@
# limitations under the License.
import logging
+from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.handlers.auth import (
+ convert_client_dict_legacy_fields_to_identifier,
+ login_id_phone_to_thirdparty,
+)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -26,64 +31,36 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID
-from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.types import JsonDict, UserID
+from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
-def login_submission_legacy_convert(submission):
- """
- If the input login submission is an old style object
- (ie. with top-level user / medium / address) convert it
- to a typed object.
- """
- if "user" in submission:
- submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
- del submission["user"]
-
- if "medium" in submission and "address" in submission:
- submission["identifier"] = {
- "type": "m.id.thirdparty",
- "medium": submission["medium"],
- "address": submission["address"],
- }
- del submission["medium"]
- del submission["address"]
-
-
-def login_id_thirdparty_from_phone(identifier):
- """
- Convert a phone login identifier type to a generic threepid identifier
- Args:
- identifier(dict): Login identifier dict of type 'm.id.phone'
-
- Returns: Login identifier dict of type 'm.id.threepid'
- """
- if "country" not in identifier or "number" not in identifier:
- raise SynapseError(400, "Invalid phone-type identifier")
-
- msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
-
- return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
-
-
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
- JWT_TYPE = "m.login.jwt"
+ JWT_TYPE = "org.matrix.login.jwt"
+ JWT_TYPE_DEPRECATED = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
self.hs = hs
+
+ # JWT configuration variables.
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
+ self.jwt_issuer = hs.config.jwt_issuer
+ self.jwt_audiences = hs.config.jwt_audiences
+
+ # SSO configuration.
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -104,10 +81,11 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
- def on_GET(self, request):
+ def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
+ flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
if self.cas_enabled:
# we advertise CAS for backwards compat, though MSC1721 renamed it
@@ -131,20 +109,21 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
- def on_OPTIONS(self, request):
+ def on_OPTIONS(self, request: SynapseRequest):
return 200, {}
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
try:
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
+ or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
- result = await self.do_jwt_login(login_submission)
+ result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = await self.do_token_login(login_submission)
+ result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError:
@@ -155,14 +134,14 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- async def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
- dict: HTTP response
+ HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -174,18 +153,11 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
- login_submission_legacy_convert(login_submission)
-
- if "identifier" not in login_submission:
- raise SynapseError(400, "Missing param: identifier")
-
- identifier = login_submission["identifier"]
- if "type" not in identifier:
- raise SynapseError(400, "Login identifier has no type")
+ identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
- identifier = login_id_thirdparty_from_phone(identifier)
+ identifier = login_id_phone_to_thirdparty(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
@@ -195,11 +167,14 @@ class LoginRestServlet(RestServlet):
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- address = address.lower()
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
@@ -277,25 +252,30 @@ class LoginRestServlet(RestServlet):
return result
async def _complete_login(
- self, user_id, login_submission, callback=None, create_non_existent_users=False
- ):
+ self,
+ user_id: str,
+ login_submission: JsonDict,
+ callback: Optional[
+ Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
+ ] = None,
+ create_non_existent_users: bool = False,
+ ) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
- all succesful logins.
+ all successful logins.
- Applies the ratelimiting for succesful login attempts against an
+ Applies the ratelimiting for successful login attempts against an
account.
Args:
- user_id (str): ID of the user to register.
- login_submission (dict): Dictionary of login information.
- callback (func|None): Callback function to run after registration.
- create_non_existent_users (bool): Whether to create the user if
- they don't exist. Defaults to False.
+ user_id: ID of the user to register.
+ login_submission: Dictionary of login information.
+ callback: Callback function to run after registration.
+ create_non_existent_users: Whether to create the user if they don't
+ exist. Defaults to False.
Returns:
- result (Dict[str,str]): Dictionary of account information after
- successful registration.
+ result: Dictionary of account information after successful registration.
"""
# Before we actually log them in we check if they've already logged in
@@ -329,7 +309,7 @@ class LoginRestServlet(RestServlet):
return result
- async def do_token_login(self, login_submission):
+ async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -339,28 +319,32 @@ class LoginRestServlet(RestServlet):
result = await self._complete_login(user_id, login_submission)
return result
- async def do_jwt_login(self, login_submission):
+ async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
- 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
+ 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
import jwt
- from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(
- token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ token,
+ self.jwt_secret,
+ algorithms=[self.jwt_algorithm],
+ issuer=self.jwt_issuer,
+ audience=self.jwt_audiences,
+ )
+ except jwt.PyJWTError as e:
+ # A JWT error occurred, return some info back to the client.
+ raise LoginError(
+ 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
)
- except jwt.ExpiredSignatureError:
- raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
- except InvalidTokenError:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload.get("sub", None)
if user is None:
- raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+ raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index eec16f8ad8..970fdd5834 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -17,8 +17,6 @@
"""
import logging
-from six import string_types
-
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -51,7 +49,9 @@ class PresenceStatusRestServlet(RestServlet):
raise AuthError(403, "You are not allowed to see their presence.")
state = await self.presence_handler.get_state(target_user=user)
- state = format_user_presence_state(state, self.clock.time_msec())
+ state = format_user_presence_state(
+ state, self.clock.time_msec(), include_user_id=False
+ )
return 200, state
@@ -71,7 +71,7 @@ class PresenceStatusRestServlet(RestServlet):
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
- if not isinstance(state["status_msg"], string_types):
+ if not isinstance(state["status_msg"], str):
raise SynapseError(400, "status_msg must be a string.")
if content:
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9fd4908136..e781a3bcf4 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.api.errors import (
NotFoundError,
StoreError,
@@ -25,7 +24,7 @@ from synapse.http.servlet import (
parse_json_value_from_request,
parse_string,
)
-from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -45,6 +44,8 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
+ self._users_new_default_push_rules = hs.config.users_new_default_push_rules
+
async def on_PUT(self, request, path):
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
@@ -158,10 +159,10 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
def notify_user(self, user_id):
- stream_id, _ = self.store.get_push_rules_stream_token()
+ stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
- def set_rule_attr(self, user_id, spec, val):
+ async def set_rule_attr(self, user_id, spec, val):
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -171,7 +172,9 @@ class PushRuleRestServlet(RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
+ return await self.store.set_push_rule_enabled(
+ user_id, namespaced_rule_id, val
+ )
elif spec["attr"] == "actions":
actions = val.get("actions")
_check_actions(actions)
@@ -179,9 +182,14 @@ class PushRuleRestServlet(RestServlet):
rule_id = spec["rule_id"]
is_default_rule = rule_id.startswith(".")
if is_default_rule:
- if namespaced_rule_id not in BASE_RULE_IDS:
+ if user_id in self._users_new_default_push_rules:
+ rule_ids = NEW_RULE_IDS
+ else:
+ rule_ids = BASE_RULE_IDS
+
+ if namespaced_rule_id not in rule_ids:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
- return self.store.set_push_rule_actions(
+ return await self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule
)
else:
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 550a2f1b44..5f65cb7d83 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -16,7 +16,7 @@
import logging
from synapse.api.errors import Codes, StoreError, SynapseError
-from synapse.http.server import finish_request
+from synapse.http.server import respond_with_html_bytes
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -177,13 +177,9 @@ class PushersRemoveRestServlet(RestServlet):
self.notifier.on_new_replication_data()
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(
- b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),)
+ respond_with_html_bytes(
+ request, 200, PushersRemoveRestServlet.SUCCESS_HTML,
)
- request.write(PushersRemoveRestServlet.SUCCESS_HTML)
- finish_request(request)
return None
def on_OPTIONS(self, _):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 105e0cf4d2..84baf3d59b 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -15,13 +15,11 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
+
import logging
import re
from typing import List, Optional
-
-from six.moves.urllib import parse as urlparse
-
-from canonicaljson import json
+from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
@@ -29,6 +27,7 @@ from synapse.api.errors import (
Codes,
HttpResponseException,
InvalidClientCredentialsError,
+ ShadowBanError,
SynapseError,
)
from synapse.api.filtering import Filter
@@ -46,6 +45,8 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.util import json_decoder
+from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
@@ -170,7 +171,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
room_id=room_id,
event_type=event_type,
state_key=state_key,
- is_guest=requester.is_guest,
)
if not data:
@@ -200,28 +200,29 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if state_key is not None:
event_dict["state_key"] = state_key
- if event_type == EventTypes.Member:
- membership = content.get("membership", None)
- event_id, _ = await self.room_member_handler.update_membership(
- requester,
- target=UserID.from_string(state_key),
- room_id=room_id,
- action=membership,
- content=content,
- )
- else:
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, txn_id=txn_id
- )
- event_id = event.event_id
+ try:
+ if event_type == EventTypes.Member:
+ membership = content.get("membership", None)
+ event_id, _ = await self.room_member_handler.update_membership(
+ requester,
+ target=UserID.from_string(state_key),
+ room_id=room_id,
+ action=membership,
+ content=content,
+ )
+ else:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- ret = {} # type: dict
- if event_id:
- set_tag("event_id", event_id)
- ret = {"event_id": event_id}
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
@@ -251,12 +252,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, txn_id=txn_id
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- set_tag("event_id", event.event_id)
- return 200, {"event_id": event.event_id}
+ set_tag("event_id", event_id)
+ return 200, {"event_id": event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
@@ -446,7 +454,7 @@ class RoomMemberListRestServlet(RestServlet):
async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
@@ -518,10 +526,12 @@ class RoomMessageListRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
- filter_bytes = parse_string(request, b"filter", encoding=None)
- if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
- event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@@ -630,10 +640,12 @@ class RoomEventContextServlet(RestServlet):
limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages
- filter_bytes = parse_string(request, "filter")
- if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes)
- event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
else:
event_filter = None
@@ -718,16 +730,20 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- await self.room_member_handler.do_3pid_invite(
- room_id,
- requester.user,
- content["medium"],
- content["address"],
- content["id_server"],
- requester,
- txn_id,
- content.get("id_access_token"),
- )
+ try:
+ await self.room_member_handler.do_3pid_invite(
+ room_id,
+ requester.user,
+ content["medium"],
+ content["address"],
+ content["id_server"],
+ requester,
+ txn_id,
+ content.get("id_access_token"),
+ )
+ except ShadowBanError:
+ # Pretend the request succeeded.
+ pass
return 200, {}
target = requester.user
@@ -739,15 +755,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if "reason" in content:
event_content = {"reason": content["reason"]}
- await self.room_member_handler.update_membership(
- requester=requester,
- target=target,
- room_id=room_id,
- action=membership_action,
- txn_id=txn_id,
- third_party_signed=content.get("third_party_signed", None),
- content=event_content,
- )
+ try:
+ await self.room_member_handler.update_membership(
+ requester=requester,
+ target=target,
+ room_id=room_id,
+ action=membership_action,
+ txn_id=txn_id,
+ third_party_signed=content.get("third_party_signed", None),
+ content=event_content,
+ )
+ except ShadowBanError:
+ # Pretend the request succeeded.
+ pass
return_value = {}
@@ -785,20 +805,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Redaction,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "redacts": event_id,
- },
- txn_id=txn_id,
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.Redaction,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ "redacts": event_id,
+ },
+ txn_id=txn_id,
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- set_tag("event_id", event.event_id)
- return 200, {"event_id": event.event_id}
+ set_tag("event_id", event_id)
+ return 200, {"event_id": event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
set_tag("txn_id", txn_id)
@@ -819,9 +846,18 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
+ # If we're not on the typing writer instance we should scream if we get
+ # requests.
+ self._is_typing_writer = (
+ hs.config.worker.writers.typing == hs.get_instance_name()
+ )
+
async def on_PUT(self, request, room_id, user_id):
requester = await self.auth.get_user_by_req(request)
+ if not self._is_typing_writer:
+ raise Exception("Got /typing request on instance that is not typing writer")
+
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
@@ -832,17 +868,21 @@ class RoomTypingRestServlet(RestServlet):
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
- if content["typing"]:
- await self.typing_handler.started_typing(
- target_user=target_user,
- auth_user=requester.user,
- room_id=room_id,
- timeout=timeout,
- )
- else:
- await self.typing_handler.stopped_typing(
- target_user=target_user, auth_user=requester.user, room_id=room_id
- )
+ try:
+ if content["typing"]:
+ await self.typing_handler.started_typing(
+ target_user=target_user,
+ requester=requester,
+ room_id=room_id,
+ timeout=timeout,
+ )
+ else:
+ await self.typing_handler.stopped_typing(
+ target_user=target_user, requester=requester, room_id=room_id
+ )
+ except ShadowBanError:
+ # Pretend this worked without error.
+ pass
return 200, {}
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 747d46eac2..50277c6cf6 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet):
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
- password = base64.b64encode(mac.digest())
+ password = base64.b64encode(mac.digest()).decode("ascii")
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index bc11b4dda4..f016b4f1bd 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -17,24 +17,32 @@
"""
import logging
import re
-
-from twisted.internet import defer
+from typing import Iterable, Pattern
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
-def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
+def client_patterns(
+ path_regex: str,
+ releases: Iterable[int] = (0,),
+ unstable: bool = True,
+ v1: bool = False,
+) -> Iterable[Pattern]:
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
- path_regex (str): The regex string to match. This should NOT have a ^
+ path_regex: The regex string to match. This should NOT have a ^
as this will be prefixed.
+ releases: An iterable of releases to include this endpoint under.
+ unstable: If true, include this endpoint under the "unstable" prefix.
+ v1: If true, include this endpoint under the "api/v1" prefix.
Returns:
- SRE_Pattern
+ An iterable of patterns.
"""
patterns = []
@@ -51,7 +59,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
return patterns
-def set_timeline_upper_limit(filter_json, filter_timeline_limit):
+def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:
+ """
+ Enforces a maximum limit of a timeline query.
+
+ Params:
+ filter_json: The timeline query to modify.
+ filter_timeline_limit: The maximum limit to allow, passing -1 will
+ disable enforcing a maximum limit.
+ """
if filter_timeline_limit < 0:
return # no upper limits
timeline = filter_json.get("room", {}).get("timeline", {})
@@ -64,34 +80,22 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
- Takes a on_POST method which returns a deferred (errcode, body) response
+ Takes a on_POST method which returns an Awaitable (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
# ...
- yield self.auth_handler.check_auth
- """
+ await self.auth_handler.check_auth
+ """
- def wrapped(*args, **kwargs):
- res = defer.ensureDeferred(orig(*args, **kwargs))
- res.addErrback(_catch_incomplete_interactive_auth)
- return res
+ async def wrapped(*args, **kwargs):
+ try:
+ return await orig(*args, **kwargs)
+ except InteractiveAuthIncompleteError as e:
+ return 401, e.result
return wrapped
-
-
-def _catch_incomplete_interactive_auth(f):
- """helper for interactive_auth_handler
-
- Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
-
- Args:
- f (failure.Failure):
- """
- f.trap(InteractiveAuthIncompleteError)
- return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 8f9440da9a..cad3f9bbb7 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -15,23 +15,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-
-from six.moves import http_client
+import random
+from http import HTTPStatus
from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
+from synapse.api.errors import (
+ Codes,
+ InteractiveAuthIncompleteError,
+ SynapseError,
+ ThreepidValidationError,
+)
from synapse.config.emailconfig import ThreepidBehaviour
-from synapse.http.server import finish_request
+from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import Mailer, load_jinja2_templates
+from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -49,21 +54,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_password_reset_template_html,
- self.config.email_password_reset_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_password_reset_template_html,
+ template_text=self.config.email_password_reset_template_text,
)
async def on_POST(self, request):
@@ -84,7 +79,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database. This allows the user to reset his password without having to
+ # know the exact spelling (eg. upper and lower case) of address in the database.
+ # Stored in the database "foo@bar.com"
+ # User requests with "FOO@bar.com" would raise a Not Found error
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -95,6 +98,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # The email will be sent to the stored address.
+ # This avoids a potential account hijack by requesting a password reset to
+ # an email address which is controlled by the attacker but which, after
+ # canonicalisation, matches the one in our database.
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -103,6 +110,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -153,9 +163,8 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_password_reset_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_password_reset_template_failure_html
)
async def on_GET(self, request, medium):
@@ -198,17 +207,16 @@ class PasswordResetSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
- html = self.config.email_password_reset_template_success_html
- request.setResponseCode(200)
+ html = self.config.email_password_reset_template_success_html_content
+ status_code = 200
except ThreepidValidationError as e:
- request.setResponseCode(e.code)
+ status_code = e.code
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
- request.write(html.encode("utf-8"))
- finish_request(request)
+ respond_with_html(request, status_code, html)
class PasswordRestServlet(RestServlet):
@@ -229,18 +237,12 @@ class PasswordRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
- if "new_password" in body:
- new_password = body.pop("new_password")
+ new_password = body.pop("new_password", None)
+ if new_password is not None:
if not isinstance(new_password, str) or len(new_password) > 512:
raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "new_password_hash" in body:
- raise SynapseError(400, "Unexpected property: new_password_hash")
- body["new_password_hash"] = await self.auth_handler.hash(new_password)
-
# there are two possibilities here. Either the user does not have an
# access token, and needs to do a password reset; or they have one and
# need to validate their identity.
@@ -253,33 +255,62 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
- params = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ try:
+ params, session_id = await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
user_id = requester.user.to_string()
else:
requester = None
- result, params, _ = await self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY]],
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
- )
+ try:
+ result, params, session_id = await self.auth_handler.check_ui_auth(
+ [[LoginType.EMAIL_IDENTITY]],
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "modify your account password",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth, but
+ # they're not required to provide the password again.
+ #
+ # If a password is available now, hash the provided password and
+ # store it for later.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if "medium" not in threepid or "address" not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid["medium"] == "email":
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
- threepid["address"] = threepid["address"].lower()
+ try:
+ threepid["address"] = canonicalise_email(threepid["address"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
# if using email, we must know about the email they're authing with!
threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"]
@@ -291,12 +322,21 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- assert_params_in_dict(params, ["new_password_hash"])
- new_password_hash = params["new_password_hash"]
+ # If we have a password in this request, prefer it. Otherwise, there
+ # must be a password hash from an earlier request.
+ if new_password:
+ password_hash = await self.auth_handler.hash(new_password)
+ else:
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
+
logout_devices = params.get("logout_devices", True)
await self._set_password_handler.set_password(
- user_id, new_password_hash, logout_devices, requester
+ user_id, password_hash, logout_devices, requester
)
return 200, {}
@@ -321,7 +361,7 @@ class DeactivateAccountRestServlet(RestServlet):
erase = body.get("erase", False)
if not isinstance(erase, bool):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'erase' must be a boolean, if given",
Codes.BAD_JSON,
)
@@ -364,19 +404,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_add_threepid_template_html,
- self.config.email_add_threepid_template_text,
- ],
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_add_threepid_template_html,
+ template_text=self.config.email_add_threepid_template_text,
)
async def on_POST(self, request):
@@ -394,7 +426,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database.
+ # This ensures that the validation email is sent to the canonicalised address
+ # as it will later be entered into the database.
+ # Otherwise the email will be sent to "FOO@bar.com" and stored as
+ # "foo@bar.com" in database.
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -405,14 +446,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = await self.store.get_user_id_by_threepid(
- "email", body["email"]
- )
+ existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -481,6 +523,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@@ -524,9 +569,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_add_threepid_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_add_threepid_template_failure_html
)
async def on_GET(self, request):
@@ -571,16 +615,15 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Otherwise show the success template
html = self.config.email_add_threepid_template_success_html_content
- request.setResponseCode(200)
+ status_code = 200
except ThreepidValidationError as e:
- request.setResponseCode(e.code)
+ status_code = e.code
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
- request.write(html.encode("utf-8"))
- finish_request(request)
+ respond_with_html(request, status_code, html)
class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 2f10fa64e2..d06336ceea 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -16,7 +16,7 @@
import logging
from synapse.api.errors import AuthError, SynapseError
-from synapse.http.server import finish_request
+from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet
from ._base import client_patterns
@@ -26,9 +26,6 @@ logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
- SUCCESS_HTML = (
- b"<html><body>Your account has been successfully renewed.</body><html>"
- )
def __init__(self, hs):
"""
@@ -59,11 +56,7 @@ class AccountValidityRenewServlet(RestServlet):
status_code = 404
response = self.failure_html
- request.setResponseCode(status_code)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(response),))
- request.write(response.encode("utf8"))
- finish_request(request)
+ respond_with_html(request, status_code, response)
class AccountValiditySendMailServlet(RestServlet):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 75590ebaeb..8e585e9153 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -18,7 +18,7 @@ import logging
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.http.server import finish_request
+from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet, parse_string
from ._base import client_patterns
@@ -200,13 +200,7 @@ class AuthRestServlet(RestServlet):
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
+ respond_with_html(request, 200, html)
return None
async def on_POST(self, request, stagetype):
@@ -263,13 +257,7 @@ class AuthRestServlet(RestServlet):
raise SynapseError(404, "Unknown auth stage type")
# Render the HTML and return.
- html_bytes = html.encode("utf8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
+ respond_with_html(request, 200, html)
return None
def on_OPTIONS(self, _):
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 1efe60f3a7..075afdd32b 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,6 +16,7 @@
import logging
+from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
@@ -325,6 +326,9 @@ class GroupRoomServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s was not legal group ID" % (group_id,))
+
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index e2efc47024..ae1a8c4e6c 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -16,16 +16,16 @@
import hmac
import logging
+import random
from typing import List, Union
-from six import string_types
-
import synapse
import synapse.api.auth
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
+ InteractiveAuthIncompleteError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
@@ -38,18 +38,18 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
-from synapse.http.server import finish_request
+from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import load_jinja2_templates
+from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
-from synapse.util.threepids import check_3pid_allowed
+from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -82,23 +82,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- from synapse.push.mailer import Mailer, load_jinja2_templates
-
- template_html, template_text = load_jinja2_templates(
- self.config.email_template_dir,
- [
- self.config.email_registration_template_html,
- self.config.email_registration_template_text,
- ],
- apply_format_ts_filter=True,
- apply_mxc_to_http_filter=True,
- public_baseurl=self.config.public_baseurl,
- )
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=template_html,
- template_text=template_text,
+ template_html=self.config.email_registration_template_html,
+ template_text=self.config.email_registration_template_text,
)
async def on_POST(self, request):
@@ -118,7 +106,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
- email = body["email"]
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See on_POST in EmailThreepidRequestTokenRestServlet
+ # in synapse/rest/client/v2_alpha/account.py)
+ try:
+ email = canonicalise_email(body["email"])
+ except ValueError as e:
+ raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -130,13 +125,16 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- "email", body["email"]
+ "email", email
)
if existing_user_id is not None:
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -209,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
+ # Also wait for some random amount of time between 100ms and 1s to make it
+ # look like we did something.
+ await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
@@ -256,15 +257,8 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
- )
-
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- (self.failure_email_template,) = load_jinja2_templates(
- self.config.email_template_dir,
- [self.config.email_registration_template_failure_html],
+ self._failure_email_template = (
+ self.config.email_registration_template_failure_html
)
async def on_GET(self, request, medium):
@@ -306,17 +300,15 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Otherwise show the success template
html = self.config.email_registration_template_success_html_content
-
- request.setResponseCode(200)
+ status_code = 200
except ThreepidValidationError as e:
- request.setResponseCode(e.code)
+ status_code = e.code
# Show a failure page with a reason
template_vars = {"failure_reason": e.msg}
- html = self.failure_email_template.render(**template_vars)
+ html = self._failure_email_template.render(**template_vars)
- request.write(html.encode("utf-8"))
- finish_request(request)
+ respond_with_html(request, status_code, html)
class UsernameAvailabilityRestServlet(RestServlet):
@@ -384,6 +376,7 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self._registration_enabled = self.hs.config.enable_registration
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
@@ -409,32 +402,17 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
- # we do basic sanity checks here because the auth layer will store these
- # in sessions. Pull out the username/password provided to us.
- if "password" in body:
- password = body.pop("password")
- if not isinstance(password, string_types) or len(password) > 512:
- raise SynapseError(400, "Invalid password")
- self.password_policy_handler.validate_password(password)
-
- # If the password is valid, hash it and store it back on the body.
- # This ensures that only the hashed password is handled everywhere.
- if "password_hash" in body:
- raise SynapseError(400, "Unexpected property: password_hash")
- body["password_hash"] = await self.auth_handler.hash(password)
-
+ # Pull out the provided username and do basic sanity checks early since
+ # the auth layer will store these in sessions.
desired_username = None
if "username" in body:
- if (
- not isinstance(body["username"], string_types)
- or len(body["username"]) > 512
- ):
+ if not isinstance(body["username"], str) or len(body["username"]) > 512:
raise SynapseError(400, "Invalid username")
desired_username = body["username"]
appservice = None
if self.auth.has_access_token(request):
- appservice = await self.auth.get_appservice_by_req(request)
+ appservice = self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -453,28 +431,41 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
- if isinstance(desired_username, string_types):
+ if isinstance(desired_username, str):
result = await self._do_appservice_registration(
desired_username, access_token, body
)
return 200, result # we throw for non 200 responses
- # for regular registration, downcase the provided username before
- # attempting to register it. This should mean
- # that people who try to register with upper-case in their usernames
- # don't get a nasty surprise. (Note that we treat username
- # case-insenstively in login, so they are free to carry on imagining
- # that their username is CrAzYh4cKeR if that keeps them happy)
- if desired_username is not None:
- desired_username = desired_username.lower()
-
# == Normal User Registration == (everyone else)
- if not self.hs.config.enable_registration:
+ if not self._registration_enabled:
raise SynapseError(403, "Registration has been disabled")
+ # For regular registration, convert the provided username to lowercase
+ # before attempting to register it. This should mean that people who try
+ # to register with upper-case in their usernames don't get a nasty surprise.
+ #
+ # Note that we treat usernames case-insensitively in login, so they are
+ # free to carry on imagining that their username is CrAzYh4cKeR if that
+ # keeps them happy.
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
+ # Check if this account is upgrading from a guest account.
guest_access_token = body.get("guest_access_token", None)
- if "initial_device_display_name" in body and "password_hash" not in body:
+ # Pull out the provided password and do basic sanity checks early.
+ #
+ # Note that we remove the password from the body since the auth layer
+ # will store the body in the session and we don't want a plaintext
+ # password store there.
+ password = body.pop("password", None)
+ if password is not None:
+ if not isinstance(password, str) or len(password) > 512:
+ raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(password)
+
+ if "initial_device_display_name" in body and password is None:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
@@ -484,6 +475,7 @@ class RegisterRestServlet(RestServlet):
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
+ password_hash = None
if session_id:
# if we get a registered user id out of here, it means we previously
# registered a user for this session, so we could just return the
@@ -492,7 +484,12 @@ class RegisterRestServlet(RestServlet):
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
+ # Extract the previously-hashed password from the session.
+ password_hash = await self.auth_handler.get_session_data(
+ session_id, "password_hash", None
+ )
+ # Ensure that the username is valid.
if desired_username is not None:
await self.registration_handler.check_username(
desired_username,
@@ -500,20 +497,38 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id,
)
- auth_result, params, session_id = await self.auth_handler.check_auth(
- self._registration_flows,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "register a new account",
- )
+ # Check if the user-interactive authentication flows are complete, if
+ # not this will raise a user-interactive auth error.
+ try:
+ auth_result, params, session_id = await self.auth_handler.check_ui_auth(
+ self._registration_flows,
+ request,
+ body,
+ self.hs.get_ip_from_request(request),
+ "register a new account",
+ )
+ except InteractiveAuthIncompleteError as e:
+ # The user needs to provide more steps to complete auth.
+ #
+ # Hash the password and store it with the session since the client
+ # is not required to provide the password again.
+ #
+ # If a password hash was previously stored we will not attempt to
+ # re-hash and store it for efficiency. This assumes the password
+ # does not change throughout the authentication flow, but this
+ # should be fine since the data is meant to be consistent.
+ if not password_hash and password:
+ password_hash = await self.auth_handler.hash(password)
+ await self.auth_handler.set_session_data(
+ e.session_id, "password_hash", password_hash
+ )
+ raise
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
-
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
@@ -535,12 +550,15 @@ class RegisterRestServlet(RestServlet):
# don't re-register the threepids
registered = False
else:
- # NB: This may be from the auth handler and NOT from the POST
- assert_params_in_dict(params, ["password_hash"])
+ # If we have a password in this request, prefer it. Otherwise, there
+ # might be a password hash from an earlier request.
+ if password:
+ password_hash = await self.auth_handler.hash(password)
+ if not password_hash:
+ raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
desired_username = params.get("username", None)
guest_access_token = params.get("guest_access_token", None)
- new_password_hash = params.get("password_hash", None)
if desired_username is not None:
desired_username = desired_username.lower()
@@ -559,6 +577,15 @@ class RegisterRestServlet(RestServlet):
if login_type in auth_result:
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See on_POST in EmailThreepidRequestTokenRestServlet
+ # in synapse/rest/client/v2_alpha/account.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
existing_user_id = await self.store.get_user_id_by_threepid(
medium, address
@@ -571,12 +598,17 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
+ entries = await self.store.get_user_agents_ips_to_ui_auth_session(
+ session_id
+ )
+
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
- password_hash=new_password_hash,
+ password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
address=client_addr,
+ user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
@@ -586,8 +618,8 @@ class RegisterRestServlet(RestServlet):
):
await self.store.upsert_monthly_active_user(registered_user_id)
- # remember that we've now registered that user account, and with
- # what user ID (since the user may not have specified)
+ # Remember that the user account has been registered (and the user
+ # ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
@@ -626,7 +658,7 @@ class RegisterRestServlet(RestServlet):
(object) params: registration parameters, from which we pull
device_id, initial_device_name and inhibit_login
Returns:
- defer.Deferred: (object) dictionary for response from /register
+ dictionary for response from /register
"""
result = {"user_id": user_id, "home_server": self.hs.hostname}
if not params.get("inhibit_login", False):
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 89002ffbff..e29f49f7f5 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
import logging
from synapse.api.constants import EventTypes, RelationTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import ShadowBanError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.util.stringutils import random_string
from ._base import client_patterns
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict=event_dict, txn_id=txn_id
- )
+ try:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict=event_dict, txn_id=txn_id
+ )
+ event_id = event.event_id
+ except ShadowBanError:
+ event_id = "$" + random_string(43)
- return 200, {"event_id": event.event_id}
+ return 200, {"event_id": event_id}
class RelationPaginationServlet(RestServlet):
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index f067b5edac..e15927c4ea 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -14,9 +14,7 @@
# limitations under the License.
import logging
-
-from six import string_types
-from six.moves import http_client
+from http import HTTPStatus
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
@@ -47,15 +45,15 @@ class ReportEventRestServlet(RestServlet):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("reason", "score"))
- if not isinstance(body["reason"], string_types):
+ if not isinstance(body["reason"], str):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'reason' must be a string",
Codes.BAD_JSON,
)
if not isinstance(body["score"], int):
raise SynapseError(
- http_client.BAD_REQUEST,
+ HTTPStatus.BAD_REQUEST,
"Param 'score' must be an integer",
Codes.BAD_JSON,
)
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index f357015a70..39a5518614 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,13 +15,14 @@
import logging
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.util import stringutils
from ._base import client_patterns
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",))
- new_version = content["new_version"]
new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
if new_version is None:
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION,
)
- new_room_id = await self._room_creation_handler.upgrade_room(
- requester, room_id, new_version
- )
+ try:
+ new_room_id = await self._room_creation_handler.upgrade_room(
+ requester, room_id, new_version
+ )
+ except ShadowBanError:
+ # Generate a random room ID.
+ new_room_id = stringutils.random_string(18)
ret = {"replacement_room": new_room_id}
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
new file mode 100644
index 0000000000..2492634dac
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Half-Shot
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class UserSharedRoomsServlet(RestServlet):
+ """
+ GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
+ """
+
+ PATTERNS = client_patterns(
+ "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
+ releases=(), # This is an unstable feature
+ )
+
+ def __init__(self, hs):
+ super(UserSharedRoomsServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.user_directory_active = hs.config.update_user_directory
+
+ async def on_GET(self, request, user_id):
+
+ if not self.user_directory_active:
+ raise SynapseError(
+ code=400,
+ msg="The user directory is disabled on this server. Cannot determine shared rooms.",
+ errcode=Codes.FORBIDDEN,
+ )
+
+ UserID.from_string(user_id)
+
+ requester = await self.auth.get_user_by_req(request)
+ if user_id == requester.user.to_string():
+ raise SynapseError(
+ code=400,
+ msg="You cannot request a list of shared rooms with yourself",
+ errcode=Codes.FORBIDDEN,
+ )
+ rooms = await self.store.get_shared_rooms_for_users(
+ requester.user.to_string(), user_id
+ )
+
+ return 200, {"joined": list(rooms)}
+
+
+def register_servlets(hs, http_server):
+ UserSharedRoomsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 8fa68dd37f..a0b00135e1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -16,8 +16,6 @@
import itertools
import logging
-from canonicaljson import json
-
from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
+from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
@@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
- filter_object = json.loads(filter_id)
+ filter_object = json_decoder.decode(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
@@ -178,14 +177,22 @@ class SyncRestServlet(RestServlet):
full_state=full_state,
)
+ # the client may have disconnected by now; don't bother to serialize the
+ # response if so.
+ if request._disconnected:
+ logger.info("Client has disconnected; not serializing response.")
+ return 200, {}
+
time_now = self.clock.time_msec()
response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
+ logger.debug("Event formatting complete")
return 200, response_content
async def encode_response(self, time_now, sync_result, access_token_id, filter):
+ logger.debug("Formatting events in sync response")
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == "federation":
@@ -213,6 +220,7 @@ class SyncRestServlet(RestServlet):
event_formatter,
)
+ logger.debug("building sync response dict")
return {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
@@ -417,6 +425,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
+ result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0d668df0b6..24ac57f35d 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
+ # Implements additional endpoints as described in MSC2666
+ "uk.half-shot.msc2666": True,
},
},
)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 1ddf9997ff..b3e4d5612e 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -16,22 +16,15 @@
import hmac
import logging
from hashlib import sha256
+from http import HTTPStatus
from os import path
-from six.moves import http_client
-
import jinja2
from jinja2 import TemplateNotFound
-from twisted.internet import defer
-
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
-from synapse.http.server import (
- DirectServeResource,
- finish_request,
- wrap_html_request_handler,
-)
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_string
from synapse.types import UserID
@@ -49,7 +42,7 @@ else:
return a == b
-class ConsentResource(DirectServeResource):
+class ConsentResource(DirectServeHtmlResource):
"""A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template.
@@ -120,7 +113,6 @@ class ConsentResource(DirectServeResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
- @wrap_html_request_handler
async def _async_render_GET(self, request):
"""
Args:
@@ -141,7 +133,7 @@ class ConsentResource(DirectServeResource):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
- u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id)
+ u = await self.store.get_user_by_id(qualified_user_id)
if u is None:
raise NotFoundError("Unknown user")
@@ -161,7 +153,6 @@ class ConsentResource(DirectServeResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- @wrap_html_request_handler
async def _async_render_POST(self, request):
"""
Args:
@@ -197,12 +188,8 @@ class ConsentResource(DirectServeResource):
template_html = self._jinja_env.get_template(
path.join(TEMPLATE_LANGUAGE, template_name)
)
- html_bytes = template_html.render(**template_args).encode("utf8")
-
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
- request.write(html_bytes)
- finish_request(request)
+ html = template_html.render(**template_args)
+ respond_with_html(request, 200, html)
def _check_hash(self, userid, userhmac):
"""
@@ -223,4 +210,4 @@ class ConsentResource(DirectServeResource):
)
if not compare_digest(want_mac, userhmac):
- raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
+ raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect")
diff --git a/synapse/rest/health.py b/synapse/rest/health.py
new file mode 100644
index 0000000000..0170950bf3
--- /dev/null
+++ b/synapse/rest/health.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+
+class HealthResource(Resource):
+ """A resource that does nothing except return a 200 with a body of `OK`,
+ which can be used as a health check.
+
+ Note: `SynapseRequest._should_log_request` ensures that requests to
+ `/health` do not get logged at INFO.
+ """
+
+ isLeaf = 1
+
+ def render_GET(self, request):
+ request.setHeader(b"Content-Type", b"text/plain")
+ return b"OK"
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index ab671f7334..5db7f81c2d 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,23 +15,19 @@
import logging
from typing import Dict, Set
-from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json_bytes,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.util import json_decoder
logger = logging.getLogger(__name__)
-class RemoteKey(DirectServeResource):
- """HTTP resource for retreiving the TLS certificate and NACL signature
+class RemoteKey(DirectServeJsonResource):
+ """HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
@@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
isLeaf = True
def __init__(self, hs):
+ super().__init__()
+
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
- @wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
@@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- @wrap_json_request_handler
async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
@@ -206,18 +202,22 @@ class RemoteKey(DirectServeResource):
if miss:
cache_misses.setdefault(server_name, set()).add(key_id)
+ # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
for ts_added, result in results:
+ # Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
+ # If there is a cache miss, request the missing keys, then recurse (and
+ # ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:
- key_json = json.loads(key_json)
+ key_json = json_decoder.decode(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
@@ -225,4 +225,4 @@ class RemoteKey(DirectServeResource):
results = {"server_keys": signed_keys}
- respond_with_json_bytes(request, 200, encode_canonical_json(results))
+ respond_with_json(request, 200, results, canonical_json=True)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 3689777266..6568e61829 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -16,10 +16,10 @@
import logging
import os
+import urllib
+from typing import Awaitable
-from six.moves import urllib
-
-from twisted.internet import defer
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -78,8 +78,9 @@ def respond_404(request):
)
-@defer.inlineCallbacks
-def respond_with_file(request, media_type, file_path, file_size=None, upload_name=None):
+async def respond_with_file(
+ request, media_type, file_path, file_size=None, upload_name=None
+):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@@ -90,7 +91,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
+ await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
@@ -199,8 +200,9 @@ def _can_encode_filename_as_token(x):
return True
-@defer.inlineCallbacks
-def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+async def respond_with_responder(
+ request, responder, media_type, file_size, upload_name=None
+):
"""Responds to the request with given responder. If responder is None then
returns 404.
@@ -219,7 +221,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
- yield responder.write_to_consumer(request)
+ await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us
@@ -233,21 +235,21 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
finish_request(request)
-class Responder(object):
+class Responder:
"""Represents a response that can be streamed to the requester.
Responder is a context manager which *must* be used, so that any resources
held can be cleaned up.
"""
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
- consumer (IConsumer)
+ consumer: The consumer to stream into.
Returns:
- Deferred: Resolves once the response has finished being written
+ Resolves once the response has finished being written
"""
pass
@@ -258,7 +260,7 @@ class Responder(object):
pass
-class FileInfo(object):
+class FileInfo:
"""Details about a requested/uploaded file.
Attributes:
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 9f747de263..68dd2a1c8a 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -14,16 +14,10 @@
# limitations under the License.
#
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json,
- wrap_json_request_handler,
-)
-
-class MediaConfigResource(DirectServeResource):
+class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
@@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- @wrap_json_request_handler
async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
- return NOT_DONE_YET
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 24d3ae5bbc..d3d8457303 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -15,18 +15,14 @@
import logging
import synapse.http.servlet
-from synapse.http.server import (
- DirectServeResource,
- set_cors_headers,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
-class DownloadResource(DirectServeResource):
+class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
self.media_repo = media_repo
self.server_name = hs.hostname
- # this is expected by @wrap_json_request_handler
- self.clock = hs.get_clock()
-
- @wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader(
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index e25c382c9c..d2826374a7 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -33,7 +33,7 @@ def _wrap_in_base_path(func):
return _wrapped
-class MediaFilePaths(object):
+class MediaFilePaths:
"""Describes where files are stored on disk.
Most of the functions have a `*_rel` variant which returns a file path that
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index fd10d42f2f..9a1b7779f7 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,12 +18,11 @@ import errno
import logging
import os
import shutil
-from typing import Dict, Tuple
-
-from six import iteritems
+from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
+from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -42,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
+ Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -62,7 +62,7 @@ logger = logging.getLogger(__name__)
UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
-class MediaRepository(object):
+class MediaRepository:
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
@@ -137,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
- self, media_type, upload_name, content, content_length, auth_user
- ):
+ self,
+ media_type: str,
+ upload_name: str,
+ content: IO,
+ content_length: int,
+ auth_user: str,
+ ) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
- media_type(str): The content type of the file
- upload_name(str): The name of the file
+ media_type: The content type of the file
+ upload_name: The name of the file
content: A file like object that is the content to store
- content_length(int): The length of the content
- auth_user(str): The user_id of the uploader
+ content_length: The length of the content
+ auth_user: The user_id of the uploader
Returns:
- Deferred[str]: The mxc url of the stored content
+ The mxc url of the stored content
"""
media_id = random_string(24)
@@ -172,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
- async def get_local_media(self, request, media_id, name):
+ async def get_local_media(
+ self, request: Request, media_id: str, name: Optional[str]
+ ) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
- request(twisted.web.http.Request)
- media_id (str): The media ID of the content. (This is the same as
+ request: The incoming request.
+ media_id: The media ID of the content. (This is the same as
the file_id for local content.)
- name (str|None): Optional name that, if specified, will be used as
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@@ -205,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
- async def get_remote_media(self, request, server_name, media_id, name):
+ async def get_remote_media(
+ self, request: Request, server_name: str, media_id: str, name: Optional[str]
+ ) -> None:
"""Respond to requests for remote media.
Args:
- request(twisted.web.http.Request)
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
- name (str|None): Optional name that, if specified, will be used as
+ request: The incoming request.
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
+ name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
- Deferred: Resolves once a response has successfully been written
- to request
+ Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@@ -247,17 +253,16 @@ class MediaRepository(object):
else:
respond_404(request)
- async def get_remote_media_info(self, server_name, media_id):
+ async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
- remote server).
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the remote server).
Returns:
- Deferred[dict]: The media_info of the file
+ The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@@ -280,7 +285,9 @@ class MediaRepository(object):
return media_info
- async def _get_remote_media_impl(self, server_name, media_id):
+ async def _get_remote_media_impl(
+ self, server_name: str, media_id: str
+ ) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -290,7 +297,7 @@ class MediaRepository(object):
remote server).
Returns:
- Deferred[(Responder, media_info)]
+ A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -321,26 +328,28 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(self, server_name, media_id, file_id):
+ async def _download_remote_file(
+ self, server_name: str, media_id: str, file_id: str
+ ) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
- server_name (str): Originating server
- media_id (str): The media ID of the content (as defined by the
+ server_name: Originating server
+ media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
- file_id (str): Local file ID
+ file_id: Local file ID
Returns:
- Deferred[MediaInfo]
+ The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
request_path = "/".join(
- ("/_matrix/media/v1/download", server_name, media_id)
+ ("/_matrix/media/r0/download", server_name, media_id)
)
try:
length, headers = await self.client.get_file(
@@ -551,25 +560,31 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
- self, server_name, media_id, file_id, media_type, url_cache=False
- ):
+ self,
+ server_name: Optional[str],
+ media_id: str,
+ file_id: str,
+ media_type: str,
+ url_cache: bool = False,
+ ) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
- server_name (str|None): The server name if remote media, else None if local
- media_id (str): The media ID of the content. (This is the same as
+ server_name: The server name if remote media, else None if local
+ media_id: The media ID of the content. (This is the same as
the file_id for local content)
- file_id (str): Local file ID
- media_type (str): The content type of the file
- url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
+ file_id: Local file ID
+ media_type: The content type of the file
+ url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
- Deferred[dict]: Dict with "width" and "height" keys of original image
+ Dict with "width" and "height" keys of original image or None if the
+ media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
- return
+ return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
@@ -586,7 +601,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
@@ -606,7 +621,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it
- for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
+ for (t_width, t_height, t_type), t_method in thumbnails.items():
# Generate the thumbnail
if t_method == "crop":
t_byte_source = await defer_to_thread(
@@ -705,7 +720,7 @@ class MediaRepositoryResource(Resource):
Uploads are POSTed to a resource which returns a token which is used to GET
the download::
- => POST /_matrix/media/v1/upload HTTP/1.1
+ => POST /_matrix/media/r0/upload HTTP/1.1
Content-Type: <media-type>
Content-Length: <content-length>
@@ -716,7 +731,7 @@ class MediaRepositoryResource(Resource):
{ "content_uri": "mxc://<server-name>/<media-id>" }
- => GET /_matrix/media/v1/download/<server-name>/<media-id> HTTP/1.1
+ => GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: <media-type>
@@ -727,7 +742,7 @@ class MediaRepositoryResource(Resource):
Clients can get thumbnails by supplying a desired width and height and
thumbnailing method::
- => GET /_matrix/media/v1/thumbnail/<server_name>
+ => GET /_matrix/media/r0/thumbnail/<server_name>
/<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
<= HTTP/1.1 200 OK
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 683a79c966..3a352b5631 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -12,73 +12,79 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import contextlib
import logging
import os
import shutil
-import sys
-
-import six
+from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
-from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
-from ._base import Responder
+from ._base import FileInfo, Responder
+from .filepath import MediaFilePaths
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+ from .storage_provider import StorageProviderWrapper
logger = logging.getLogger(__name__)
-class MediaStorage(object):
+class MediaStorage:
"""Responsible for storing/fetching files from local sources.
Args:
- hs (synapse.server.Homeserver)
- local_media_directory (str): Base path where we store media on disk
- filepaths (MediaFilePaths)
- storage_providers ([StorageProvider]): List of StorageProvider that are
- used to fetch and store files.
+ hs
+ local_media_directory: Base path where we store media on disk
+ filepaths
+ storage_providers: List of StorageProvider that are used to fetch and store files.
"""
- def __init__(self, hs, local_media_directory, filepaths, storage_providers):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ local_media_directory: str,
+ filepaths: MediaFilePaths,
+ storage_providers: Sequence["StorageProviderWrapper"],
+ ):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
- @defer.inlineCallbacks
- def store_file(self, source, file_info):
+ async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Returns:
- Deferred[str]: the file path written to in the primary media store
+ the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- yield defer_to_thread(
+ await defer_to_thread(
self.hs.get_reactor(), _write_file_synchronously, source, f
)
- yield finish_cb()
+ await finish_cb()
return fname
@contextlib.contextmanager
- def store_into_file(self, file_info):
+ def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
- on disk, and finish_cb is a function that returns a Deferred.
+ on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
@@ -88,13 +94,13 @@ class MediaStorage(object):
error.
Args:
- file_info (FileInfo): Info about the file to store
+ file_info: Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
- yield finish_cb()
+ await finish_cb()
"""
path = self._file_info_to_path(file_info)
@@ -106,10 +112,9 @@ class MediaStorage(object):
finished_called = [False]
- @defer.inlineCallbacks
- def finish():
+ async def finish():
for provider in self.storage_providers:
- yield provider.store_file(path, file_info)
+ await provider.store_file(path, file_info)
finished_called[0] = True
@@ -117,27 +122,24 @@ class MediaStorage(object):
with open(fname, "wb") as f:
yield f, fname, finish
except Exception:
- t, v, tb = sys.exc_info()
try:
os.remove(fname)
except Exception:
pass
- six.reraise(t, v, tb)
+ raise
if not finished_called:
raise Exception("Finished callback not called")
- @defer.inlineCallbacks
- def fetch_media(self, file_info):
+ async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[Responder|None]: Returns a Responder if the file was found,
- otherwise None.
+ Returns a Responder if the file was found, otherwise None.
"""
path = self._file_info_to_path(file_info)
@@ -146,23 +148,22 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = await provider.fetch(path, file_info) # type: Any
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
return None
- @defer.inlineCallbacks
- def ensure_media_is_in_local_cache(self, file_info):
+ async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
- file_info (FileInfo)
+ file_info
Returns:
- Deferred[str]: Full path to local file
+ Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
@@ -174,29 +175,23 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
- res = yield provider.fetch(path, file_info)
+ res = await provider.fetch(path, file_info) # type: Any
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.hs.get_reactor()
)
- yield res.write_to_consumer(consumer)
- yield consumer.wait()
+ await res.write_to_consumer(consumer)
+ await consumer.wait()
return local_path
raise Exception("file could not be found")
- def _file_info_to_path(self, file_info):
+ def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
-
- Args:
- file_info (FileInfo)
-
- Returns:
- str
"""
if file_info.url_cache:
if file_info.thumbnail:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index f206605727..cd8c246594 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -24,28 +24,24 @@ import shutil
import sys
import traceback
from typing import Dict, Optional
+from urllib import parse as urlparse
-import six
-from six import string_types
-from six.moves import urllib_parse as urlparse
+import attr
-from canonicaljson import json
-
-from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
- DirectServeResource,
+ DirectServeJsonResource,
respond_with_json,
respond_with_json_bytes,
- wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.stringutils import random_string
@@ -60,8 +56,67 @@ _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
+ONE_HOUR = 60 * 60 * 1000
+
+# A map of globs to API endpoints.
+_oembed_globs = {
+ # Twitter.
+ "https://publish.twitter.com/oembed": [
+ "https://twitter.com/*/status/*",
+ "https://*.twitter.com/*/status/*",
+ "https://twitter.com/*/moments/*",
+ "https://*.twitter.com/*/moments/*",
+ # Include the HTTP versions too.
+ "http://twitter.com/*/status/*",
+ "http://*.twitter.com/*/status/*",
+ "http://twitter.com/*/moments/*",
+ "http://*.twitter.com/*/moments/*",
+ ],
+}
+# Convert the globs to regular expressions.
+_oembed_patterns = {}
+for endpoint, globs in _oembed_globs.items():
+ for glob in globs:
+ # Convert the glob into a sane regular expression to match against. The
+ # rules followed will be slightly different for the domain portion vs.
+ # the rest.
+ #
+ # 1. The scheme must be one of HTTP / HTTPS (and have no globs).
+ # 2. The domain can have globs, but we limit it to characters that can
+ # reasonably be a domain part.
+ # TODO: This does not attempt to handle Unicode domain names.
+ # 3. Other parts allow a glob to be any one, or more, characters.
+ results = urlparse.urlparse(glob)
+
+ # Ensure the scheme does not have wildcards (and is a sane scheme).
+ if results.scheme not in {"http", "https"}:
+ raise ValueError("Insecure oEmbed glob scheme: %s" % (results.scheme,))
+
+ pattern = urlparse.urlunparse(
+ [
+ results.scheme,
+ re.escape(results.netloc).replace("\\*", "[a-zA-Z0-9_-]+"),
+ ]
+ + [re.escape(part).replace("\\*", ".+") for part in results[2:]]
+ )
+ _oembed_patterns[re.compile(pattern)] = endpoint
+
+
+@attr.s
+class OEmbedResult:
+ # Either HTML content or URL must be provided.
+ html = attr.ib(type=Optional[str])
+ url = attr.ib(type=Optional[str])
+ title = attr.ib(type=Optional[str])
+ # Number of seconds to cache the content.
+ cache_age = attr.ib(type=int)
-class PreviewUrlResource(DirectServeResource):
+
+class OEmbedError(Exception):
+ """An error occurred processing the oEmbed object."""
+
+
+class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -85,6 +140,15 @@ class PreviewUrlResource(DirectServeResource):
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
+ # We run the background jobs if we're the instance specified (or no
+ # instance is specified, where we assume there is only one instance
+ # serving media).
+ instance_running_jobs = hs.config.media.media_instance_running_background_jobs
+ self._worker_run_media_background_jobs = (
+ instance_running_jobs is None
+ or instance_running_jobs == hs.get_instance_name()
+ )
+
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
self.url_preview_accept_language = hs.config.url_preview_accept_language
@@ -94,18 +158,18 @@ class PreviewUrlResource(DirectServeResource):
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
- expiry_ms=60 * 60 * 1000,
+ expiry_ms=ONE_HOUR,
)
- self._cleaner_loop = self.clock.looping_call(
- self._start_expire_url_cache_data, 10 * 1000
- )
+ if self._worker_run_media_background_jobs:
+ self._cleaner_loop = self.clock.looping_call(
+ self._start_expire_url_cache_data, 10 * 1000
+ )
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
request.setHeader(b"Allow", b"OPTIONS, GET")
- return respond_with_json(request, 200, {}, send_cors=True)
+ respond_with_json(request, 200, {}, send_cors=True)
- @wrap_json_request_handler
async def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?
@@ -163,19 +227,19 @@ class PreviewUrlResource(DirectServeResource):
else:
logger.info("Returning cached response")
- og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
+ og = await make_deferred_yieldable(observable.observe())
respond_with_json_bytes(request, 200, og, send_cors=True)
- async def _do_preview(self, url, user, ts):
+ async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
- url (str):
- user (str):
- ts (int):
+ url: The URL to preview.
+ user: The user requesting the preview.
+ ts: The timestamp requested for the preview.
Returns:
- Deferred[bytes]: json-encoded og data
+ json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
@@ -188,7 +252,7 @@ class PreviewUrlResource(DirectServeResource):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
- if isinstance(og, six.text_type):
+ if isinstance(og, str):
og = og.encode("utf8")
return og
@@ -290,7 +354,7 @@ class PreviewUrlResource(DirectServeResource):
logger.debug("Calculated OG for %s as %s", url, og)
- jsonog = json.dumps(og)
+ jsonog = json_encoder.encode(og)
# store OG in history-aware DB cache
await self.store.store_url_cache(
@@ -305,6 +369,87 @@ class PreviewUrlResource(DirectServeResource):
return jsonog.encode("utf8")
+ def _get_oembed_url(self, url: str) -> Optional[str]:
+ """
+ Check whether the URL should be downloaded as oEmbed content instead.
+
+ Params:
+ url: The URL to check.
+
+ Returns:
+ A URL to use instead or None if the original URL should be used.
+ """
+ for url_pattern, endpoint in _oembed_patterns.items():
+ if url_pattern.fullmatch(url):
+ return endpoint
+
+ # No match.
+ return None
+
+ async def _get_oembed_content(self, endpoint: str, url: str) -> OEmbedResult:
+ """
+ Request content from an oEmbed endpoint.
+
+ Params:
+ endpoint: The oEmbed API endpoint.
+ url: The URL to pass to the API.
+
+ Returns:
+ An object representing the metadata returned.
+
+ Raises:
+ OEmbedError if fetching or parsing of the oEmbed information fails.
+ """
+ try:
+ logger.debug("Trying to get oEmbed content for url '%s'", url)
+ result = await self.client.get_json(
+ endpoint,
+ # TODO Specify max height / width.
+ # Note that only the JSON format is supported.
+ args={"url": url},
+ )
+
+ # Ensure there's a version of 1.0.
+ if result.get("version") != "1.0":
+ raise OEmbedError("Invalid version: %s" % (result.get("version"),))
+
+ oembed_type = result.get("type")
+
+ # Ensure the cache age is None or an int.
+ cache_age = result.get("cache_age")
+ if cache_age:
+ cache_age = int(cache_age)
+
+ oembed_result = OEmbedResult(None, None, result.get("title"), cache_age)
+
+ # HTML content.
+ if oembed_type == "rich":
+ oembed_result.html = result.get("html")
+ return oembed_result
+
+ if oembed_type == "photo":
+ oembed_result.url = result.get("url")
+ return oembed_result
+
+ # TODO Handle link and video types.
+
+ if "thumbnail_url" in result:
+ oembed_result.url = result.get("thumbnail_url")
+ return oembed_result
+
+ raise OEmbedError("Incompatible oEmbed information.")
+
+ except OEmbedError as e:
+ # Trap OEmbedErrors first so we can directly re-raise them.
+ logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
+ raise
+
+ except Exception as e:
+ # Trap any exception and let the code follow as usual.
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
+ raise OEmbedError() from e
+
async def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
@@ -314,54 +459,90 @@ class PreviewUrlResource(DirectServeResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ # If this URL can be accessed via oEmbed, use that instead.
+ url_to_download = url
+ oembed_url = self._get_oembed_url(url)
+ if oembed_url:
+ # The result might be a new URL to download, or it might be HTML content.
try:
- logger.debug("Trying to get preview for url '%s'", url)
- length, headers, uri, code = await self.client.get_file(
- url,
- output_stream=f,
- max_size=self.max_spider_size,
- headers={"Accept-Language": self.url_preview_accept_language},
- )
- except SynapseError:
- # Pass SynapseErrors through directly, so that the servlet
- # handler will return a SynapseError to the client instead of
- # blank data or a 500.
- raise
- except DNSLookupError:
- # DNS lookup returned no results
- # Note: This will also be the case if one of the resolved IP
- # addresses is blacklisted
- raise SynapseError(
- 502,
- "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN,
- )
- except Exception as e:
- # FIXME: pass through 404s and other error messages nicely
- logger.warning("Error downloading %s: %r", url, e)
+ oembed_result = await self._get_oembed_content(oembed_url, url)
+ if oembed_result.url:
+ url_to_download = oembed_result.url
+ elif oembed_result.html:
+ url_to_download = None
+ except OEmbedError:
+ # If an error occurs, try doing a normal preview.
+ pass
- raise SynapseError(
- 500,
- "Failed to download content: %s"
- % (traceback.format_exception_only(sys.exc_info()[0], e),),
- Codes.UNKNOWN,
- )
- await finish()
+ if url_to_download:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ logger.debug("Trying to get preview for url '%s'", url_to_download)
+ length, headers, uri, code = await self.client.get_file(
+ url_to_download,
+ output_stream=f,
+ max_size=self.max_spider_size,
+ headers={"Accept-Language": self.url_preview_accept_language},
+ )
+ except SynapseError:
+ # Pass SynapseErrors through directly, so that the servlet
+ # handler will return a SynapseError to the client instead of
+ # blank data or a 500.
+ raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
+ )
+ except Exception as e:
+ # FIXME: pass through 404s and other error messages nicely
+ logger.warning("Error downloading %s: %r", url_to_download, e)
+
+ raise SynapseError(
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
+ Codes.UNKNOWN,
+ )
+ await finish()
+
+ if b"Content-Type" in headers:
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ else:
+ media_type = "application/octet-stream"
+
+ download_name = get_filename_from_headers(headers)
+
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ expires = ONE_HOUR
+ etag = headers["ETag"][0] if "ETag" in headers else None
+ else:
+ html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ f.write(html_bytes)
+ await finish()
+
+ media_type = "text/html"
+ download_name = oembed_result.title
+ length = len(html_bytes)
+ # If a specific cache age was not given, assume 1 hour.
+ expires = oembed_result.cache_age or ONE_HOUR
+ uri = oembed_url
+ code = 200
+ etag = None
try:
- if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode("ascii")
- else:
- media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
- download_name = get_filename_from_headers(headers)
-
await self.store.store_local_media(
media_id=file_id,
media_type=media_type,
- time_now_ms=self.clock.time_msec(),
+ time_now_ms=time_now_ms,
upload_name=download_name,
media_length=length,
user_id=user,
@@ -384,10 +565,8 @@ class PreviewUrlResource(DirectServeResource):
"filename": fname,
"uri": uri,
"response_code": code,
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- "expires": 60 * 60 * 1000,
- "etag": headers["ETag"][0] if "ETag" in headers else None,
+ "expires": expires,
+ "etag": etag,
}
def _start_expire_url_cache_data(self):
@@ -400,11 +579,13 @@ class PreviewUrlResource(DirectServeResource):
"""
# TODO: Delete from backup media store
+ assert self._worker_run_media_background_jobs
+
now = self.clock.time_msec()
logger.debug("Running url preview cache expiry")
- if not (await self.store.db.updates.has_completed_background_updates()):
+ if not (await self.store.db_pool.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
@@ -442,7 +623,7 @@ class PreviewUrlResource(DirectServeResource):
# These may be cached for a bit on the client (i.e., they
# may have a room open with a preview url thing open).
# So we wait a couple of days before deleting, just in case.
- expire_before = now - 2 * 24 * 60 * 60 * 1000
+ expire_before = now - 2 * 24 * ONE_HOUR
media_ids = await self.store.get_url_cache_media_before(expire_before)
removed_media = []
@@ -631,7 +812,7 @@ def _iterate_over_text(tree, *tags_to_ignore):
if el is None:
return
- if isinstance(el, string_types):
+ if isinstance(el, str):
yield el
elif el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..18c9ed48d6 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,65 +13,66 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
import os
import shutil
-
-from twisted.internet import defer
+from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
-class StorageProvider(object):
+class StorageProvider:
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
-
- Returns:
- Deferred
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
"""
- pass
- def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
- path (str): Relative path of file in local cache
- file_info (FileInfo)
+ path: Relative path of file in local cache
+ file_info: The metadata of the file.
Returns:
- Deferred(Responder): Returns a Responder if the provider has the file,
- otherwise returns None.
+ Returns a Responder if the provider has the file, otherwise returns None.
"""
- pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
- backend (StorageProvider)
- store_local (bool): Whether to store new local files or not.
- store_synchronous (bool): Whether to wait for file to be successfully
+ backend: The storage provider to wrap.
+ store_local: Whether to store new local files or not.
+ store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background.
- store_remote (bool): Whether remote media should be uploaded
+ store_remote: Whether remote media should be uploaded
"""
- def __init__(self, backend, store_local, store_synchronous, store_remote):
+ def __init__(
+ self,
+ backend: StorageProvider,
+ store_local: bool,
+ store_synchronous: bool,
+ store_remote: bool,
+ ):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
@@ -80,28 +81,38 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
- return defer.succeed(None)
+ return None
if file_info.server_name and not self.store_remote:
- return defer.succeed(None)
+ return None
if self.store_synchronous:
- return self.backend.store_file(path, file_info)
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
else:
# TODO: Handle errors.
- def store():
+ async def store():
try:
- return self.backend.store_file(path, file_info)
+ result = self.backend.store_file(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
except Exception:
logger.exception("Error storing file")
run_in_background(store)
- return defer.succeed(None)
+ return None
- def fetch(self, path, file_info):
- return self.backend.fetch(path, file_info)
+ async def fetch(self, path, file_info):
+ # store_file is supposed to return an Awaitable, but guard
+ # against improper implementations.
+ result = self.backend.fetch(path, file_info)
+ if inspect.isawaitable(result):
+ return await result
class FileStorageProviderBackend(StorageProvider):
@@ -120,7 +131,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- def store_file(self, path, file_info):
+ async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -130,11 +141,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return defer_to_thread(
+ return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- def fetch(self, path, file_info):
+ async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 0b87220234..a83535b97b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,11 +16,7 @@
import logging
-from synapse.http.server import (
- DirectServeResource,
- set_cors_headers,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from ._base import (
@@ -34,7 +30,7 @@ from ._base import (
logger = logging.getLogger(__name__)
-class ThumbnailResource(DirectServeResource):
+class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- self.clock = hs.get_clock()
- @wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index c234ea7421..d681bf7bf0 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from io import BytesIO
-import PIL.Image as Image
+from PIL import Image as Image
logger = logging.getLogger(__name__)
@@ -32,7 +31,7 @@ EXIF_TRANSPOSE_MAPPINGS = {
}
-class Thumbnailer(object):
+class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 83d005812d..3ebf7a68e6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,20 +15,14 @@
import logging
-from twisted.web.server import NOT_DONE_YET
-
from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import (
- DirectServeResource,
- respond_with_json,
- wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
-class UploadResource(DirectServeResource):
+class UploadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
- def render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
- return NOT_DONE_YET
- @wrap_json_request_handler
async def _async_render_POST(self, request):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py
index c03194f001..f7a0bc4bdb 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/oidc/callback_resource.py
@@ -14,18 +14,17 @@
# limitations under the License.
import logging
-from synapse.http.server import DirectServeResource, wrap_html_request_handler
+from synapse.http.server import DirectServeHtmlResource
logger = logging.getLogger(__name__)
-class OIDCCallbackResource(DirectServeResource):
+class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1
def __init__(self, hs):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
- @wrap_html_request_handler
async def _async_render_GET(self, request):
- return await self._oidc_handler.handle_oidc_callback(request)
+ await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 75e58043b4..c10188a5d7 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -16,10 +16,10 @@
from twisted.python import failure
from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource, return_html_error
-class SAML2ResponseResource(DirectServeResource):
+class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 20177b44e7..f591cc6c5c 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -13,17 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
from twisted.web.resource import Resource
from synapse.http.server import set_cors_headers
+from synapse.util import json_encoder
logger = logging.getLogger(__name__)
-class WellKnownBuilder(object):
+class WellKnownBuilder:
"""Utility to construct the well-known response
Args:
@@ -67,4 +67,4 @@ class WellKnownResource(Resource):
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
- return json.dumps(r).encode("utf-8")
+ return json_encoder.encode(r).encode("utf-8")
|