diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 87f927890c..40f5c32db2 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -13,8 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.rest.admin
from synapse.http.server import JsonResource
+from synapse.rest import admin
from synapse.rest.client import versions
from synapse.rest.client.v1 import (
directory,
@@ -123,9 +123,7 @@ class ClientRestResource(JsonResource):
password_policy.register_servlets(hs, client_resource)
# moving to /_synapse/admin
- synapse.rest.admin.register_servlets_for_client_rest_resource(
- hs, client_resource
- )
+ admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable
shared_rooms.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 1c88c93f38..57cac22252 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -16,13 +16,13 @@
import logging
import platform
-import re
import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
+ admin_patterns,
assert_requester_is_admin,
historical_admin_path_patterns,
)
@@ -31,6 +31,7 @@ from synapse.rest.admin.devices import (
DeviceRestServlet,
DevicesRestServlet,
)
+from synapse.rest.admin.event_reports import EventReportsRestServlet
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
@@ -49,6 +50,7 @@ from synapse.rest.admin.users import (
ResetPasswordRestServlet,
SearchUsersRestServlet,
UserAdminServlet,
+ UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
UsersRestServlet,
@@ -61,7 +63,7 @@ logger = logging.getLogger(__name__)
class VersionServlet(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
+ PATTERNS = admin_patterns("/server_version$")
def __init__(self, hs):
self.res = {
@@ -107,7 +109,8 @@ class PurgeHistoryRestServlet(RestServlet):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- token = await self.store.get_topological_token_for_event(event_id)
+ room_token = await self.store.get_topological_token_for_event(event_id)
+ token = await room_token.to_string(self.store)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body:
@@ -209,11 +212,13 @@ def register_servlets(hs, http_server):
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
+ UserMembershipRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
+ EventReportsRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d82eaf5e38..db9fea263a 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -44,7 +44,7 @@ def historical_admin_path_patterns(path_regex):
]
-def admin_patterns(path_regex: str):
+def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
Args:
@@ -54,7 +54,7 @@ def admin_patterns(path_regex: str):
Returns:
A list of regex patterns.
"""
- admin_prefix = "^/_synapse/admin/v1"
+ admin_prefix = "^/_synapse/admin/" + version
patterns = [re.compile(admin_prefix + path_regex)]
return patterns
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 8d32677339..a163863322 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import re
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -21,7 +20,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
-from synapse.rest.admin._base import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -32,14 +31,12 @@ class DeviceRestServlet(RestServlet):
Get, update or delete the given user's device
"""
- PATTERNS = (
- re.compile(
- "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$"
- ),
+ PATTERNS = admin_patterns(
+ "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
def __init__(self, hs):
- super(DeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -98,7 +95,7 @@ class DevicesRestServlet(RestServlet):
Retrieve the given user's devices
"""
- PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs):
"""
@@ -131,9 +128,7 @@ class DeleteDevicesRestServlet(RestServlet):
key which lists the device_ids to delete.
"""
- PATTERNS = (
- re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"),
- )
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs):
self.hs = hs
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
new file mode 100644
index 0000000000..5b8d0594cd
--- /dev/null
+++ b/synapse/rest/admin/event_reports.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+
+logger = logging.getLogger(__name__)
+
+
+class EventReportsRestServlet(RestServlet):
+ """
+ List all reported events that are known to the homeserver. Results are returned
+ in a dictionary containing report information. Supports pagination.
+ The requester must have administrator access in Synapse.
+
+ GET /_synapse/admin/v1/event_reports
+ returns:
+ 200 OK with list of reports if success otherwise an error.
+
+ Args:
+ The parameters `from` and `limit` are required only for pagination.
+ By default, a `limit` of 100 is used.
+ The parameter `dir` can be used to define the order of results.
+ The parameter `user_id` can be used to filter by user id.
+ The parameter `room_id` can be used to filter by room id.
+ Returns:
+ A list of reported events and an integer representing the total number of
+ reported events that exist given this query
+ """
+
+ PATTERNS = admin_patterns("/event_reports$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ direction = parse_string(request, "dir", default="b")
+ user_id = parse_string(request, "user_id")
+ room_id = parse_string(request, "room_id")
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "The start parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "The limit parameter must be a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ event_reports, total = await self.store.get_event_reports_paginate(
+ start, limit, direction, user_id, room_id
+ )
+ ret = {"event_reports": event_reports, "total": total}
+ if (start + limit) < total:
+ ret["next_token"] = start + len(event_reports)
+
+ return 200, ret
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index f474066542..8b7bb6d44e 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,14 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
-
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
class PurgeRoomServlet(RestServlet):
@@ -35,7 +34,7 @@ class PurgeRoomServlet(RestServlet):
{}
"""
- PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),)
+ PATTERNS = admin_patterns("/purge_room$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 6e9a874121..375d055445 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import re
-
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -22,6 +20,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import UserID
@@ -56,13 +55,13 @@ class SendServerNoticeServlet(RestServlet):
self.snm = hs.get_server_notices_manager()
def register(self, json_resource):
- PATTERN = "^/_synapse/admin/v1/send_server_notice"
+ PATTERN = "/send_server_notice"
json_resource.register_paths(
- "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
+ "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
)
json_resource.register_paths(
"PUT",
- (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
+ admin_patterns(PATTERN + "/(?P<txn_id>[^/]*)$"),
self.on_PUT,
self.__class__.__name__,
)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f3e77da850..20dc1d0e05 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -15,7 +15,6 @@
import hashlib
import hmac
import logging
-import re
from http import HTTPStatus
from synapse.api.constants import UserTypes
@@ -29,6 +28,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.rest.admin._base import (
+ admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
@@ -60,7 +60,7 @@ class UsersRestServlet(RestServlet):
class UsersRestServletV2(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
+ PATTERNS = admin_patterns("/users$", "v2")
"""Get request to list all local users.
This needs user to have administrator access in Synapse.
@@ -105,7 +105,7 @@ class UsersRestServletV2(RestServlet):
class UserRestServletV2(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
"""Get request to list user details.
This needs user to have administrator access in Synapse.
@@ -642,7 +642,7 @@ class UserAdminServlet(RestServlet):
{}
"""
- PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),)
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs):
self.hs = hs
@@ -683,3 +683,29 @@ class UserAdminServlet(RestServlet):
await self.store.set_server_admin(target_user, set_admin_to)
return 200, {}
+
+
+class UserMembershipRestServlet(RestServlet):
+ """
+ Get room list of an user.
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
+
+ def __init__(self, hs):
+ self.is_mine = hs.is_mine
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.is_mine(UserID.from_string(user_id)):
+ raise SynapseError(400, "Can only lookup local users")
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ raise NotFoundError("User not found")
+
+ ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
+ return 200, ret
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index b210015173..faabeeb91c 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet):
PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryListServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
)
def __init__(self, hs):
- super(ClientAppserviceDirectoryListServer, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 25effd0261..1ecb77aa26 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -30,9 +30,10 @@ class EventStreamRestServlet(RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
- super(EventStreamRestServlet, self).__init__()
+ super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet):
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode("ascii")
- pagin_config = PaginationConfig.from_request(request)
+ pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args:
try:
@@ -74,7 +75,7 @@ class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(EventRestServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 910b3b4eeb..91da0ee573 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -24,14 +24,15 @@ class InitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs):
- super(InitialSyncRestServlet, self).__init__()
+ super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
- pagination_config = PaginationConfig.from_request(request)
+ pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a14618ac84..3d1693d7ac 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.appservice import ApplicationService
from synapse.handlers.auth import (
convert_client_dict_legacy_fields_to_identifier,
login_id_phone_to_thirdparty,
@@ -44,9 +45,10 @@ class LoginRestServlet(RestServlet):
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
+ APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
def __init__(self, hs):
- super(LoginRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
# JWT configuration variables.
@@ -61,6 +63,8 @@ class LoginRestServlet(RestServlet):
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self.auth = hs.get_auth()
+
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -116,8 +120,12 @@ class LoginRestServlet(RestServlet):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
+
try:
- if self.jwt_enabled and (
+ if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
+ appservice = self.auth.get_appservice_by_req(request)
+ result = await self._do_appservice_login(login_submission, appservice)
+ elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
@@ -134,6 +142,33 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
+ def _get_qualified_user_id(self, identifier):
+ if identifier["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+ if "user" not in identifier:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if identifier["user"].startswith("@"):
+ return identifier["user"]
+ else:
+ return UserID(identifier["user"], self.hs.hostname).to_string()
+
+ async def _do_appservice_login(
+ self, login_submission: JsonDict, appservice: ApplicationService
+ ):
+ logger.info(
+ "Got appservice login request with identifier: %r",
+ login_submission.get("identifier"),
+ )
+
+ identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
+ qualified_user_id = self._get_qualified_user_id(identifier)
+
+ if not appservice.is_interested_in_user(qualified_user_id):
+ raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
+
+ return await self._complete_login(qualified_user_id, login_submission)
+
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -219,15 +254,7 @@ class LoginRestServlet(RestServlet):
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
- if identifier["type"] != "m.id.user":
- raise SynapseError(400, "Unknown login identifier type")
- if "user" not in identifier:
- raise SynapseError(400, "User identifier is missing 'user' key")
-
- if identifier["user"].startswith("@"):
- qualified_user_id = identifier["user"]
- else:
- qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+ qualified_user_id = self._get_qualified_user_id(identifier)
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
@@ -255,9 +282,7 @@ class LoginRestServlet(RestServlet):
self,
user_id: str,
login_submission: JsonDict,
- callback: Optional[
- Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
- ] = None,
+ callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
@@ -270,12 +295,12 @@ class LoginRestServlet(RestServlet):
Args:
user_id: ID of the user to register.
login_submission: Dictionary of login information.
- callback: Callback function to run after registration.
+ callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
Returns:
- result: Dictionary of account information after successful registration.
+ result: Dictionary of account information after successful login.
"""
# Before we actually log them in we check if they've already logged in
@@ -310,14 +335,24 @@ class LoginRestServlet(RestServlet):
return result
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+ """
+ Handle the final stage of SSO login.
+
+ Args:
+ login_submission: The JSON request body.
+
+ Returns:
+ The body of the JSON response.
+ """
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
- result = await self._complete_login(user_id, login_submission)
- return result
+ return await self._complete_login(
+ user_id, login_submission, self.auth_handler._sso_login_callback
+ )
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
@@ -400,7 +435,7 @@ class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs):
- super(CasTicketServlet, self).__init__()
+ super().__init__()
self._cas_handler = hs.get_cas_handler()
async def on_GET(self, request: SynapseRequest) -> None:
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index b0c30b65be..f792b50cdc 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet):
PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs):
- super(LogoutRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet):
PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs):
- super(LogoutAllRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 970fdd5834..79d8e3057f 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs):
- super(PresenceStatusRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e7fe50ed72..b686cd671f 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -25,7 +25,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs):
- super(ProfileDisplaynameRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@@ -73,7 +73,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs):
- super(ProfileAvatarURLRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@@ -124,7 +124,7 @@ class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs):
- super(ProfileRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e781a3bcf4..f9eecb7cf5 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -38,7 +38,7 @@ class PushRuleRestServlet(RestServlet):
)
def __init__(self, hs):
- super(PushRuleRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet):
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):
+ if spec["attr"] not in ("enabled", "actions"):
+ # for the sake of potential future expansion, shouldn't report
+ # 404 in the case of an unknown request so check it corresponds to
+ # a known attribute first.
+ raise UnrecognizedRequestError()
+
+ namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+ rule_id = spec["rule_id"]
+ is_default_rule = rule_id.startswith(".")
+ if is_default_rule:
+ if namespaced_rule_id not in BASE_RULE_IDS:
+ raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return await self.store.set_push_rule_enabled(
- user_id, namespaced_rule_id, val
+ user_id, namespaced_rule_id, val, is_default_rule
)
elif spec["attr"] == "actions":
actions = val.get("actions")
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 5f65cb7d83..28dabf1c7a 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs):
- super(PushersRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs):
- super(PushersSetRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
- super(PushersRemoveRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 84baf3d59b..b63389e5fe 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -57,7 +57,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
def __init__(self, hs):
- super(TransactionRestServlet, self).__init__()
+ super().__init__()
self.txns = HttpTransactionCache(hs)
@@ -65,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
- super(RoomCreateRestServlet, self).__init__(hs)
+ super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
@@ -111,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomStateEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
@@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomSendEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -280,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs):
- super(JoinRoomAliasServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -343,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
- super(PublicRoomListRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@@ -448,9 +448,10 @@ class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs):
- super(RoomMemberListRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
if at_token_string is None:
at_token = None
else:
- at_token = StreamToken.from_string(at_token_string)
+ at_token = await StreamToken.from_string(self.store, at_token_string)
# let you filter down on particular memberships.
# XXX: this may not be the best shape for this API - we could pass in a filter
@@ -499,7 +500,7 @@ class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs):
- super(JoinedRoomMemberListRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@@ -518,13 +519,16 @@ class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs):
- super(RoomMessageListRestServlet, self).__init__()
+ super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(request, default_limit=10)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
@@ -557,7 +561,7 @@ class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs):
- super(RoomStateRestServlet, self).__init__()
+ super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@@ -577,13 +581,14 @@ class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs):
- super(RoomInitialSyncRestServlet, self).__init__()
+ super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(request)
+ pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
@@ -596,7 +601,7 @@ class RoomEventServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomEventServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
@@ -628,7 +633,7 @@ class RoomEventContextServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomEventContextServlet, self).__init__()
+ super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
@@ -675,7 +680,7 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomForgetRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -701,7 +706,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomMembershipRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
@@ -792,7 +797,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
- super(RoomRedactEventRestServlet, self).__init__(hs)
+ super().__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -841,7 +846,7 @@ class RoomTypingRestServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomTypingRestServlet, self).__init__()
+ super().__init__()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
@@ -914,7 +919,7 @@ class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs):
- super(SearchRestServlet, self).__init__()
+ super().__init__()
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@@ -935,7 +940,7 @@ class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs):
- super(JoinedRoomsRestServlet, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 50277c6cf6..b8d491ca5c 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet):
PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
- super(VoipRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index a206b75541..86d3d86fad 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -17,6 +17,11 @@
import logging
import random
from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -47,7 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
- super(EmailPasswordRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
self.config = hs.config
@@ -91,6 +96,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -137,86 +146,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
return 200, ret
-class PasswordResetSubmitTokenServlet(RestServlet):
- """Handles 3PID validation token submission"""
-
- PATTERNS = client_patterns(
- "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
- )
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- super(PasswordResetSubmitTokenServlet, self).__init__()
- self.hs = hs
- self.auth = hs.get_auth()
- self.config = hs.config
- self.clock = hs.get_clock()
- self.store = hs.get_datastore()
- if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
- self._failure_email_template = (
- self.config.email_password_reset_template_failure_html
- )
-
- async def on_GET(self, request, medium):
- # We currently only handle threepid token submissions for email
- if medium != "email":
- raise SynapseError(
- 400, "This medium is currently not supported for password resets"
- )
- if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "Password reset emails have been disabled due to lack of an email config"
- )
- raise SynapseError(
- 400, "Email-based password resets are disabled on this server"
- )
-
- sid = parse_string(request, "sid", required=True)
- token = parse_string(request, "token", required=True)
- client_secret = parse_string(request, "client_secret", required=True)
- assert_valid_client_secret(client_secret)
-
- # Attempt to validate a 3PID session
- try:
- # Mark the session as valid
- next_link = await self.store.validate_threepid_session(
- sid, client_secret, token, self.clock.time_msec()
- )
-
- # Perform a 302 redirect if next_link is set
- if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
-
- # Otherwise show the success template
- html = self.config.email_password_reset_template_success_html_content
- status_code = 200
- except ThreepidValidationError as e:
- status_code = e.code
-
- # Show a failure page with a reason
- template_vars = {"failure_reason": e.msg}
- html = self._failure_email_template.render(**template_vars)
-
- respond_with_html(request, status_code, html)
-
-
class PasswordRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password$")
def __init__(self, hs):
- super(PasswordRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -342,7 +276,7 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs):
- super(DeactivateAccountRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -361,7 +295,7 @@ class DeactivateAccountRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
- # allow ASes to dectivate their own users
+ # allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
@@ -390,7 +324,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
- super(EmailThreepidRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.config = hs.config
self.identity_handler = hs.get_handlers().identity_handler
@@ -439,6 +373,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
@@ -484,7 +422,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
def __init__(self, hs):
self.hs = hs
- super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
@@ -510,6 +448,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ if next_link:
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -596,15 +538,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
# Otherwise show the success template
html = self.config.email_add_threepid_template_success_html_content
@@ -665,7 +602,7 @@ class ThreepidRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs):
- super(ThreepidRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -721,7 +658,7 @@ class ThreepidAddRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/add$")
def __init__(self, hs):
- super(ThreepidAddRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -772,7 +709,7 @@ class ThreepidBindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/bind$")
def __init__(self, hs):
- super(ThreepidBindRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -801,7 +738,7 @@ class ThreepidUnbindRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/unbind$")
def __init__(self, hs):
- super(ThreepidUnbindRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -832,7 +769,7 @@ class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs):
- super(ThreepidDeleteRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -868,11 +805,50 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+ """
+ Raises a SynapseError if a given next_link value is invalid
+
+ next_link is valid if the scheme is http(s) and the next_link.domain_whitelist config
+ option is either empty or contains a domain that matches the one in the given next_link
+
+ Args:
+ hs: The homeserver object
+ next_link: The next_link value given by the client
+
+ Raises:
+ SynapseError: If the next_link is invalid
+ """
+ valid = True
+
+ # Parse the contents of the URL
+ next_link_parsed = urlparse(next_link)
+
+ # Scheme must not point to the local drive
+ if next_link_parsed.scheme == "file":
+ valid = False
+
+ # If the domain whitelist is set, the domain must be in it
+ if (
+ valid
+ and hs.config.next_link_domain_whitelist is not None
+ and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+ ):
+ valid = False
+
+ if not valid:
+ raise SynapseError(
+ 400,
+ "'next_link' domain not included in whitelist, or not http(s)",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs):
- super(WhoamiRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
async def on_GET(self, request):
@@ -883,7 +859,6 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
- PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index c1d4cd0caf..87a5b1b86b 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -34,7 +34,7 @@ class AccountDataServlet(RestServlet):
)
def __init__(self, hs):
- super(AccountDataServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -86,7 +86,7 @@ class RoomAccountDataServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomAccountDataServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index d06336ceea..bd7f9ae203 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(AccountValidityRenewServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
@@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(AccountValiditySendMailServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e585e9153..5fbfae5991 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -25,94 +25,6 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)
-RECAPTCHA_TEMPLATE = """
-<html>
-<head>
-<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<script src="https://www.recaptcha.net/recaptcha/api.js"
- async defer></script>
-<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-<script>
-function captchaDone() {
- $('#registrationForm').submit();
-}
-</script>
-</head>
-<body>
-<form id="registrationForm" method="post" action="%(myurl)s">
- <div>
- <p>
- Hello! We need to prevent computer programs and other automated
- things from creating accounts on this server.
- </p>
- <p>
- Please verify that you're not a robot.
- </p>
- <input type="hidden" name="session" value="%(session)s" />
- <div class="g-recaptcha"
- data-sitekey="%(sitekey)s"
- data-callback="captchaDone">
- </div>
- <noscript>
- <input type="submit" value="All Done" />
- </noscript>
- </div>
- </div>
-</form>
-</body>
-</html>
-"""
-
-TERMS_TEMPLATE = """
-<html>
-<head>
-<title>Authentication</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-</head>
-<body>
-<form id="registrationForm" method="post" action="%(myurl)s">
- <div>
- <p>
- Please click the button below if you agree to the
- <a href="%(terms_url)s">privacy policy of this homeserver.</a>
- </p>
- <input type="hidden" name="session" value="%(session)s" />
- <input type="submit" value="Agree" />
- </div>
-</form>
-</body>
-</html>
-"""
-
-SUCCESS_TEMPLATE = """
-<html>
-<head>
-<title>Success!</title>
-<meta name='viewport' content='width=device-width, initial-scale=1,
- user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
-<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
-<script>
-if (window.onAuthDone) {
- window.onAuthDone();
-} else if (window.opener && window.opener.postMessage) {
- window.opener.postMessage("authDone", "*");
-}
-</script>
-</head>
-<body>
- <div>
- <p>Thank you</p>
- <p>You may now close this window and return to the application</p>
- </div>
-</body>
-</html>
-"""
-
class AuthRestServlet(RestServlet):
"""
@@ -124,7 +36,7 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
- super(AuthRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -145,26 +57,30 @@ class AuthRestServlet(RestServlet):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
+ self.recaptcha_template = hs.config.recaptcha_template
+ self.terms_template = hs.config.terms_template
+ self.success_template = hs.config.fallback_success_template
+
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
if stagetype == LoginType.RECAPTCHA:
- html = RECAPTCHA_TEMPLATE % {
- "session": session,
- "myurl": "%s/r0/auth/%s/fallback/web"
+ html = self.recaptcha_template.render(
+ session=session,
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- "sitekey": self.hs.config.recaptcha_public_key,
- }
+ sitekey=self.hs.config.recaptcha_public_key,
+ )
elif stagetype == LoginType.TERMS:
- html = TERMS_TEMPLATE % {
- "session": session,
- "terms_url": "%s_matrix/consent?v=%s"
+ html = self.terms_template.render(
+ session=session,
+ terms_url="%s_matrix/consent?v=%s"
% (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
- "myurl": "%s/r0/auth/%s/fallback/web"
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
- }
+ )
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
@@ -222,14 +138,14 @@ class AuthRestServlet(RestServlet):
)
if success:
- html = SUCCESS_TEMPLATE
+ html = self.success_template.render()
else:
- html = RECAPTCHA_TEMPLATE % {
- "session": session,
- "myurl": "%s/r0/auth/%s/fallback/web"
+ html = self.recaptcha_template.render(
+ session=session,
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
- "sitekey": self.hs.config.recaptcha_public_key,
- }
+ sitekey=self.hs.config.recaptcha_public_key,
+ )
elif stagetype == LoginType.TERMS:
authdict = {"session": session}
@@ -238,18 +154,18 @@ class AuthRestServlet(RestServlet):
)
if success:
- html = SUCCESS_TEMPLATE
+ html = self.success_template.render()
else:
- html = TERMS_TEMPLATE % {
- "session": session,
- "terms_url": "%s_matrix/consent?v=%s"
+ html = self.terms_template.render(
+ session=session,
+ terms_url="%s_matrix/consent?v=%s"
% (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- "myurl": "%s/r0/auth/%s/fallback/web"
+ myurl="%s/r0/auth/%s/fallback/web"
% (CLIENT_API_PREFIX, LoginType.TERMS),
- }
+ )
elif stagetype == LoginType.SSO:
# The SSO fallback workflow should not post here,
raise SynapseError(404, "Fallback SSO auth does not support POST requests.")
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fe9d019c44..76879ac559 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(CapabilitiesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.config = hs.config
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index c0714fcfb1..7e174de692 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -35,7 +35,7 @@ class DevicesRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(DevicesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -57,7 +57,7 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
- super(DeleteDevicesRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@@ -102,7 +102,7 @@ class DeviceRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(DeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index b28da017cd..7cc692643b 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
- super(GetFilterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
- super(CreateFilterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 075afdd32b..75215a3779 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -32,7 +32,7 @@ class GroupServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
- super(GroupServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -66,7 +66,7 @@ class GroupSummaryServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
- super(GroupSummaryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -97,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupSummaryRoomsCatServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -137,7 +137,7 @@ class GroupCategoryServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupCategoryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -181,7 +181,7 @@ class GroupCategoriesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
def __init__(self, hs):
- super(GroupCategoriesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -204,7 +204,7 @@ class GroupRoleServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
def __init__(self, hs):
- super(GroupRoleServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -248,7 +248,7 @@ class GroupRolesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
def __init__(self, hs):
- super(GroupRolesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -279,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupSummaryUsersRoleServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -317,7 +317,7 @@ class GroupRoomServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
- super(GroupRoomServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -343,7 +343,7 @@ class GroupUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
- super(GroupUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -366,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
- super(GroupInvitedUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -389,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
- super(GroupSettingJoinPolicyServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@@ -413,7 +413,7 @@ class GroupCreateServlet(RestServlet):
PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
- super(GroupCreateServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -444,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminRoomsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -481,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminRoomsConfigServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -507,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminUsersInviteServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -536,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet):
)
def __init__(self, hs):
- super(GroupAdminUsersKickServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -585,7 +585,7 @@ class GroupSelfLeaveServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
def __init__(self, hs):
- super(GroupSelfLeaveServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -609,7 +609,7 @@ class GroupSelfJoinServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
def __init__(self, hs):
- super(GroupSelfJoinServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -633,7 +633,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
def __init__(self, hs):
- super(GroupSelfAcceptInviteServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@@ -657,7 +657,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
def __init__(self, hs):
- super(GroupSelfUpdatePublicityServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -680,7 +680,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
def __init__(self, hs):
- super(PublicisedGroupsForUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -701,7 +701,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups$")
def __init__(self, hs):
- super(PublicisedGroupsForUsersServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -725,7 +725,7 @@ class GroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/joined_groups$")
def __init__(self, hs):
- super(GroupsForUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 24bb090822..55c4606569 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(KeyUploadServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -147,7 +147,7 @@ class KeyQueryServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- super(KeyQueryServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -177,9 +177,10 @@ class KeyChangesServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- super(KeyChangesServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
+ self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet):
# changes after the "to" as well as before.
set_tag("to", parse_string(request, "to"))
- from_token = StreamToken.from_string(from_token_string)
+ from_token = await StreamToken.from_string(self.store, from_token_string)
user_id = requester.user.to_string()
@@ -222,7 +223,7 @@ class OneTimeKeyServlet(RestServlet):
PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
- super(OneTimeKeyServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -250,7 +251,7 @@ class SigningKeyUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SigningKeyUploadServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -308,7 +309,7 @@ class SignaturesUploadServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SignaturesUploadServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index aa911d75ee..87063ec8b1 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet):
PATTERNS = client_patterns("/notifications$")
def __init__(self, hs):
- super(NotificationsServlet, self).__init__()
+ super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index 6ae9a5a8e9..5b996e2d63 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet):
EXPIRES_MS = 3600 * 1000
def __init__(self, hs):
- super(IdTokenServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
index 968403cca4..68b27ff23a 100644
--- a/synapse/rest/client/v2_alpha/password_policy.py
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(PasswordPolicyServlet, self).__init__()
+ super().__init__()
self.policy = hs.config.password_policy
self.enabled = hs.config.password_policy_enabled
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 67cbc37312..55c6688f52 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs):
- super(ReadMarkerRestServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 92555bd4a9..6f7246a394 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet):
)
def __init__(self, hs):
- super(ReceiptRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler()
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index c589dd6c78..ec8ef9bf88 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -76,7 +76,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(EmailRegisterRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.config = hs.config
@@ -174,7 +174,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(MsisdnRegisterRequestTokenRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@@ -249,7 +249,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegistrationSubmitTokenServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.config = hs.config
@@ -319,7 +319,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(UsernameAvailabilityRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
@@ -363,7 +363,7 @@ class RegisterRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RegisterRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -431,11 +431,14 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
- if isinstance(desired_username, str):
- result = await self._do_appservice_registration(
- desired_username, access_token, body
- )
- return 200, result # we throw for non 200 responses
+ if not isinstance(desired_username, str):
+ raise SynapseError(400, "Desired Username is missing or not a string")
+
+ result = await self._do_appservice_registration(
+ desired_username, access_token, body
+ )
+
+ return 200, result
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index e29f49f7f5..18c75738f8 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -61,7 +61,7 @@ class RelationSendServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationSendServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
@@ -138,7 +138,7 @@ class RelationPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -233,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationAggregationPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
@@ -311,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
)
def __init__(self, hs):
- super(RelationAggregationGroupPaginationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e15927c4ea..215d619ca1 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
def __init__(self, hs):
- super(ReportEventRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.clock = hs.get_clock()
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 59529707df..53de97923f 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysNewVersionServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(RoomKeysVersionServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index 39a5518614..bf030e0ff4 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -53,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet):
)
def __init__(self, hs):
- super(RoomUpgradeRestServlet, self).__init__()
+ super().__init__()
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index db829f3098..bc4f43639a 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(SendToDeviceRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
index 2492634dac..c866d5151c 100644
--- a/synapse/rest/client/v2_alpha/shared_rooms.py
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -34,7 +34,7 @@ class UserSharedRoomsServlet(RestServlet):
)
def __init__(self, hs):
- super(UserSharedRoomsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a0b00135e1..6779df952f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -74,9 +74,10 @@ class SyncRestServlet(RestServlet):
ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
def __init__(self, hs):
- super(SyncRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
@@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet):
device_id=device_id,
)
+ since_token = None
if since is not None:
- since_token = StreamToken.from_string(since)
- else:
- since_token = None
+ since_token = await StreamToken.from_string(self.store, since)
# send any outstanding server notices to the user.
await self._server_notices_sender.on_user_syncing(user.to_string())
@@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
- "next_batch": sync_result.next_batch.to_string(),
+ "next_batch": await sync_result.next_batch.to_string(self.store),
}
@staticmethod
@@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet):
result = {
"timeline": {
"events": serialized_timeline,
- "prev_batch": room.timeline.prev_batch.to_string(),
+ "prev_batch": await room.timeline.prev_batch.to_string(self.store),
"limited": room.timeline.limited,
},
"state": {"events": serialized_state},
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index a3f12e8a77..bf3a79db44 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -31,7 +31,7 @@ class TagListServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
def __init__(self, hs):
- super(TagListServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -56,7 +56,7 @@ class TagServlet(RestServlet):
)
def __init__(self, hs):
- super(TagServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 23709960ad..0c127a1b5f 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs):
- super(ThirdPartyProtocolsServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
- super(ThirdPartyProtocolServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
- super(ThirdPartyUserServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
- super(ThirdPartyLocationServlet, self).__init__()
+ super().__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 83f3b6b70a..79317c74ba 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet):
PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
- super(TokenRefreshRestServlet, self).__init__()
+ super().__init__()
async def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.")
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index bef91a2d3e..ad598cefe0 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -31,7 +31,7 @@ class UserDirectorySearchRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer): server
"""
- super(UserDirectorySearchRestServlet, self).__init__()
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index c560edbc59..d24a199318 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -29,7 +29,7 @@ class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
def __init__(self, hs):
- super(VersionsRestServlet, self).__init__()
+ super().__init__()
self.config = hs.config
# Calculate these once since they shouldn't change after start-up.
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 5db7f81c2d..f843f02454 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -35,7 +35,7 @@ class RemoteKey(DirectServeJsonResource):
Supports individual GET APIs and a bulk query POST API.
- Requsts:
+ Requests:
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d2826374a7..7447eeaebe 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -80,7 +80,7 @@ class MediaFilePaths:
self, server_name, file_id, width, height, content_type, method
):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"remote_thumbnail",
server_name,
@@ -92,6 +92,23 @@ class MediaFilePaths:
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+ # Legacy path that was used to store thumbnails previously.
+ # Should be removed after some time, when most of the thumbnails are stored
+ # using the new path.
+ def remote_media_thumbnail_rel_legacy(
+ self, server_name, file_id, width, height, content_type
+ ):
+ top_level_type, sub_type = content_type.split("/")
+ file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ return os.path.join(
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
+ file_name,
+ )
+
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path,
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9a1b7779f7..e1192b47cd 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -53,7 +53,7 @@ from .media_storage import MediaStorage
from .preview_url_resource import PreviewUrlResource
from .storage_provider import StorageProviderWrapper
from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer
+from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
logger = logging.getLogger(__name__)
@@ -139,7 +139,7 @@ class MediaRepository:
async def create_content(
self,
media_type: str,
- upload_name: str,
+ upload_name: Optional[str],
content: IO,
content_length: int,
auth_user: str,
@@ -147,8 +147,8 @@ class MediaRepository:
"""Store uploaded content for a local user and return the mxc URL
Args:
- media_type: The content type of the file
- upload_name: The name of the file
+ media_type: The content type of the file.
+ upload_name: The name of the file, if provided.
content: A file like object that is the content to store
content_length: The length of the content
auth_user: The user_id of the uploader
@@ -156,6 +156,7 @@ class MediaRepository:
Returns:
The mxc url of the stored content
"""
+
media_id = random_string(24)
file_info = FileInfo(server_name=None, file_id=media_id)
@@ -460,13 +461,30 @@ class MediaRepository:
return t_byte_source
async def generate_local_exact_thumbnail(
- self, media_id, t_width, t_height, t_method, t_type, url_cache
- ):
+ self,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ url_cache: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+ media_id,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -506,14 +524,36 @@ class MediaRepository:
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def generate_remote_exact_thumbnail(
- self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
- ):
+ self,
+ server_name: str,
+ file_id: str,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+ media_id,
+ server_name,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -559,6 +599,9 @@ class MediaRepository:
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def _generate_thumbnails(
self,
server_name: Optional[str],
@@ -590,7 +633,18 @@ class MediaRepository:
FileInfo(server_name, file_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate thumbnails for remote media %s from %s of type %s: %s",
+ media_id,
+ server_name,
+ media_type,
+ e,
+ )
+ return None
+
m_width = thumbnailer.width
m_height = thumbnailer.height
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3a352b5631..a9586fb0b7 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -141,17 +141,34 @@ class MediaStorage:
Returns:
Returns a Responder if the file was found, otherwise None.
"""
+ paths = [self._file_info_to_path(file_info)]
- path = self._file_info_to_path(file_info)
- local_path = os.path.join(self.local_media_directory, path)
- if os.path.exists(local_path):
- return FileResponder(open(local_path, "rb"))
+ # fallback for remote thumbnails with no method in the filename
+ if file_info.thumbnail and file_info.server_name:
+ paths.append(
+ self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ )
+
+ for path in paths:
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ logger.debug("responding with local file %s", local_path)
+ return FileResponder(open(local_path, "rb"))
+ logger.debug("local file %s did not exist", local_path)
for provider in self.storage_providers:
- res = await provider.fetch(path, file_info) # type: Any
- if res:
- logger.debug("Streaming %s from %s", path, provider)
- return res
+ for path in paths:
+ res = await provider.fetch(path, file_info) # type: Any
+ if res:
+ logger.debug("Streaming %s from %s", path, provider)
+ return res
+ logger.debug("%s not found on %s", path, provider)
return None
@@ -170,6 +187,20 @@ class MediaStorage:
if os.path.exists(local_path):
return local_path
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return legacy_local_path
+
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index cd8c246594..dce6c4d168 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -102,7 +102,7 @@ for endpoint, globs in _oembed_globs.items():
_oembed_patterns[re.compile(pattern)] = endpoint
-@attr.s
+@attr.s(slots=True)
class OEmbedResult:
# Either HTML content or URL must be provided.
html = attr.ib(type=Optional[str])
@@ -450,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url, user):
+ async def _download_url(self, url: str, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
- url_to_download = url
+ url_to_download = url # type: Optional[str]
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@@ -520,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
expires = ONE_HOUR
- etag = headers["ETag"][0] if "ETag" in headers else None
+ etag = (
+ headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None
+ )
else:
- html_bytes = oembed_result.html.encode("utf-8") # type: ignore
+ # we can only get here if we did an oembed request and have an oembed_result.html
+ assert oembed_result.html is not None
+ assert oembed_url is not None
+
+ html_bytes = oembed_result.html.encode("utf-8")
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
f.write(html_bytes)
await finish()
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@
import logging
+from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _select_or_generate_remote_thumbnail(
self,
@@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index d681bf7bf0..32a8e4f960 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@
import logging
from io import BytesIO
-from PIL import Image as Image
+from PIL import Image
logger = logging.getLogger(__name__)
@@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = {
}
+class ThumbnailError(Exception):
+ """An error occurred generating a thumbnail."""
+
+
class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
- self.image = Image.open(input_path)
+ try:
+ self.image = Image.open(input_path)
+ except OSError as e:
+ # If an error occurs opening the image, a thumbnail won't be able to
+ # be generated.
+ raise ThumbnailError from e
+
self.width, self.height = self.image.size
self.transpose_method = None
try:
@@ -73,7 +83,7 @@ class Thumbnailer:
Args:
max_width: The largest possible width.
- max_height: The larget possible height.
+ max_height: The largest possible height.
"""
if max_width * self.height < max_height * self.width:
@@ -107,7 +117,7 @@ class Thumbnailer:
Args:
max_width: The largest possible width.
- max_height: The larget possible height.
+ max_height: The largest possible height.
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 3ebf7a68e6..d76f7389e1 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -63,6 +63,10 @@ class UploadResource(DirectServeJsonResource):
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
)
+ # If the name is falsey (e.g. an empty byte string) ensure it is None.
+ else:
+ upload_name = None
+
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index c10188a5d7..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,10 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.python import failure
-from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeHtmlResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource
class SAML2ResponseResource(DirectServeHtmlResource):
@@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource):
def __init__(self, hs):
super().__init__()
self._saml_handler = hs.get_saml_handler()
- self._error_html_template = hs.config.saml2.saml2_error_html_template
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
- f = failure.Failure(
- SynapseError(400, "Unexpected GET request on /saml2/authn_response")
+ self._saml_handler._render_error(
+ request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
- return_html_error(f, request, self._error_html_template)
async def _async_render_POST(self, request):
- try:
- await self._saml_handler.handle_saml_response(request)
- except Exception:
- f = failure.Failure()
- return_html_error(f, request, self._error_html_template)
+ await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
new file mode 100644
index 0000000000..c0b733488b
--- /dev/null
+++ b/synapse/rest/synapse/client/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
new file mode 100644
index 0000000000..9e4fbc0cbd
--- /dev/null
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.errors import ThreepidValidationError
+from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.server import DirectServeHtmlResource
+from synapse.http.servlet import parse_string
+from synapse.util.stringutils import assert_valid_client_secret
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
+ """Handles 3PID validation token submission
+
+ This resource gets mounted under /_synapse/client/password_reset/email/submit_token
+ """
+
+ isLeaf = 1
+
+ def __init__(self, hs: "HomeServer"):
+ """
+ Args:
+ hs: server
+ """
+ super().__init__()
+
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ self._local_threepid_handling_disabled_due_to_email_config = (
+ hs.config.local_threepid_handling_disabled_due_to_email_config
+ )
+ self._confirmation_email_template = (
+ hs.config.email_password_reset_template_confirmation_html
+ )
+ self._email_password_reset_template_success_html = (
+ hs.config.email_password_reset_template_success_html_content
+ )
+ self._failure_email_template = (
+ hs.config.email_password_reset_template_failure_html
+ )
+
+ # This resource should not be mounted if threepid behaviour is not LOCAL
+ assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+
+ async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
+
+ # Show a confirmation page, just in case someone accidentally clicked this link when
+ # they didn't mean to
+ template_vars = {
+ "sid": sid,
+ "token": token,
+ "client_secret": client_secret,
+ }
+ return (
+ 200,
+ self._confirmation_email_template.render(**template_vars).encode("utf-8"),
+ )
+
+ async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]:
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+
+ # Attempt to validate a 3PID session
+ try:
+ # Mark the session as valid
+ next_link = await self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
+ )
+
+ # Perform a 302 redirect if next_link is set
+ if next_link:
+ if next_link.startswith("file:///"):
+ logger.warning(
+ "Not redirecting to next_link as it is a local file: address"
+ )
+ else:
+ next_link_bytes = next_link.encode("utf-8")
+ request.setHeader("Location", next_link_bytes)
+ return (
+ 302,
+ (
+ b'You are being redirected to <a src="%s">%s</a>.'
+ % (next_link_bytes, next_link_bytes)
+ ),
+ )
+
+ # Otherwise show the success template
+ html_bytes = self._email_password_reset_template_success_html.encode(
+ "utf-8"
+ )
+ status_code = 200
+ except ThreepidValidationError as e:
+ status_code = e.code
+
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html_bytes = self._failure_email_template.render(**template_vars).encode(
+ "utf-8"
+ )
+
+ return status_code, html_bytes
|