diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 1d7c11b42d..1af8d99d20 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -100,8 +100,7 @@ class ClientRestResource(JsonResource):
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
- if is_main_process:
- directory.register_servlets(hs, client_resource)
+ directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
if is_main_process:
pusher.register_servlets(hs, client_resource)
@@ -134,8 +133,8 @@ class ClientRestResource(JsonResource):
if is_main_process:
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
room_batch.register_servlets(hs, client_resource)
+ capabilities.register_servlets(hs, client_resource)
if is_main_process:
- capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 79f22a59f1..c729364839 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -39,6 +39,7 @@ from synapse.rest.admin.event_reports import (
EventReportDetailRestServlet,
EventReportsRestServlet,
)
+from synapse.rest.admin.experimental_features import ExperimentalFeaturesRestServlet
from synapse.rest.admin.federation import (
DestinationMembershipRestServlet,
DestinationResetConnectionRestServlet,
@@ -68,7 +69,10 @@ from synapse.rest.admin.rooms import (
RoomTimestampToEventRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
-from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
+from synapse.rest.admin.statistics import (
+ LargestRoomsStatistics,
+ UserMediaStatisticsRestServlet,
+)
from synapse.rest.admin.username_available import UsernameAvailableRestServlet
from synapse.rest.admin.users import (
AccountDataRestServlet,
@@ -259,6 +263,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
UserMediaStatisticsRestServlet(hs).register(http_server)
+ LargestRoomsStatistics(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
AccountDataRestServlet(hs).register(http_server)
@@ -288,6 +293,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
BackgroundUpdateEnabledRestServlet(hs).register(http_server)
BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server)
+ ExperimentalFeaturesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 3b2f2d9abb..11ebed9bfd 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -137,6 +137,35 @@ class DevicesRestServlet(RestServlet):
devices = await self.device_handler.get_devices_by_user(target_user.to_string())
return HTTPStatus.OK, {"devices": devices, "total": len(devices)}
+ async def on_POST(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ """Creates a new device for the user."""
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.is_mine(target_user):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "Can only create devices for local users"
+ )
+
+ u = await self.store.get_user_by_id(target_user.to_string())
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ body = parse_json_object_from_request(request)
+ device_id = body.get("device_id")
+ if not device_id:
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Missing device_id")
+ if not isinstance(device_id, str):
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "device_id must be a string")
+
+ await self.device_handler.check_device_registered(
+ user_id=user_id, device_id=device_id
+ )
+
+ return HTTPStatus.CREATED, {}
+
class DeleteDevicesRestServlet(RestServlet):
"""
diff --git a/synapse/rest/admin/experimental_features.py b/synapse/rest/admin/experimental_features.py
new file mode 100644
index 0000000000..abf273af10
--- /dev/null
+++ b/synapse/rest/admin/experimental_features.py
@@ -0,0 +1,118 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Dict, Tuple
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+class ExperimentalFeature(str, Enum):
+ """
+ Currently supported per-user features
+ """
+
+ MSC3026 = "msc3026"
+ MSC3881 = "msc3881"
+ MSC3967 = "msc3967"
+
+
+class ExperimentalFeaturesRestServlet(RestServlet):
+ """
+ Enable or disable experimental features for a user or determine which features are enabled
+ for a given user
+ """
+
+ PATTERNS = admin_patterns("/experimental_features/(?P<user_id>[^/]*)")
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self.is_mine = hs.is_mine
+
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ user_id: str,
+ ) -> Tuple[int, JsonDict]:
+ """
+ List which features are enabled for a given user
+ """
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.is_mine(target_user):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "User must be local to check what experimental features are enabled.",
+ )
+
+ enabled_features = await self.store.list_enabled_features(user_id)
+
+ user_features = {}
+ for feature in ExperimentalFeature:
+ if feature in enabled_features:
+ user_features[feature] = True
+ else:
+ user_features[feature] = False
+ return HTTPStatus.OK, {"features": user_features}
+
+ async def on_PUT(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[HTTPStatus, Dict]:
+ """
+ Enable or disable the provided features for the requester
+ """
+ await assert_requester_is_admin(self.auth, request)
+
+ body = parse_json_object_from_request(request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.is_mine(target_user):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "User must be local to enable experimental features.",
+ )
+
+ features = body.get("features")
+ if not features:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST, "You must provide features to set."
+ )
+
+ # validate the provided features
+ validated_features = {}
+ for feature, enabled in features.items():
+ try:
+ validated_feature = ExperimentalFeature(feature)
+ validated_features[validated_feature] = enabled
+ except ValueError:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ f"{feature!r} is not recognised as a valid experimental feature.",
+ )
+
+ await self.store.set_features_for_user(user_id, validated_features)
+
+ return HTTPStatus.OK, {}
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c134ccfb3d..b7637dff0b 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -258,7 +258,7 @@ class DeleteMediaByID(RestServlet):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
- self.server_name = hs.hostname
+ self._is_mine_server_name = hs.is_mine_server_name
self.media_repository = hs.get_media_repository()
async def on_DELETE(
@@ -266,7 +266,7 @@ class DeleteMediaByID(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if self.server_name != server_name:
+ if not self._is_mine_server_name(server_name):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
if await self.store.get_local_media(media_id) is None:
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 4de56bf13f..1d65560265 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -70,7 +70,7 @@ class RoomRestV2Servlet(RestServlet):
self._auth = hs.get_auth()
self._store = hs.get_datastores().main
self._pagination_handler = hs.get_pagination_handler()
- self._third_party_rules = hs.get_third_party_event_rules()
+ self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
async def on_DELETE(
self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 9c45f4650d..19780e4b4c 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -113,3 +113,28 @@ class UserMediaStatisticsRestServlet(RestServlet):
ret["next_token"] = start + len(users_media)
return HTTPStatus.OK, ret
+
+
+class LargestRoomsStatistics(RestServlet):
+ """Get the largest rooms by database size.
+
+ Only works when using PostgreSQL.
+ """
+
+ PATTERNS = admin_patterns("/statistics/database/rooms$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.auth = hs.get_auth()
+ self.stats_controller = hs.get_storage_controllers().stats
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ room_sizes = await self.stats_controller.get_room_db_size_estimate()
+
+ return HTTPStatus.OK, {
+ "rooms": [
+ {"room_id": room_id, "estimated_size": size}
+ for room_id, size in room_sizes
+ ]
+ }
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 331f225116..932333ae57 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -336,7 +336,7 @@ class UserRestServletV2(RestServlet):
HTTPStatus.CONFLICT, "External id is already in use."
)
- if "avatar_url" in body and isinstance(body["avatar_url"], str):
+ if "avatar_url" in body:
await self.profile_handler.set_avatar_url(
target_user, requester, body["avatar_url"], True
)
diff --git a/synapse/rest/client/account_data.py b/synapse/rest/client/account_data.py
index 43193ad086..b1f9e9dc9b 100644
--- a/synapse/rest/client/account_data.py
+++ b/synapse/rest/client/account_data.py
@@ -13,8 +13,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
+from synapse.api.constants import AccountDataTypes, ReceiptTypes
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -29,6 +30,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+def _check_can_set_account_data_type(account_data_type: str) -> None:
+ """The fully read marker and push rules cannot be directly set via /account_data."""
+ if account_data_type == ReceiptTypes.FULLY_READ:
+ raise SynapseError(
+ 405,
+ "Cannot set m.fully_read through this API."
+ " Use /rooms/!roomId:server.name/read_markers",
+ Codes.BAD_JSON,
+ )
+ elif account_data_type == AccountDataTypes.PUSH_RULES:
+ raise SynapseError(
+ 405,
+ "Cannot set m.push_rules through this API. Use /pushrules",
+ Codes.BAD_JSON,
+ )
+
+
class AccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
@@ -46,6 +64,7 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.handler = hs.get_account_data_handler()
+ self._push_rules_handler = hs.get_push_rules_handler()
async def on_PUT(
self, request: SynapseRequest, user_id: str, account_data_type: str
@@ -54,6 +73,10 @@ class AccountDataServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
+ # Raise an error if the account data type cannot be set directly.
+ if self._hs.config.experimental.msc4010_push_rules_account_data:
+ _check_can_set_account_data_type(account_data_type)
+
body = parse_json_object_from_request(request)
# If experimental support for MSC3391 is enabled, then providing an empty dict
@@ -77,19 +100,28 @@ class AccountDataServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
- event = await self.store.get_global_account_data_by_type_for_user(
- user_id, account_data_type
- )
+ # Push rules are stored in a separate table and must be queried separately.
+ if (
+ self._hs.config.experimental.msc4010_push_rules_account_data
+ and account_data_type == AccountDataTypes.PUSH_RULES
+ ):
+ account_data: Optional[
+ JsonDict
+ ] = await self._push_rules_handler.push_rules_for_user(requester.user)
+ else:
+ account_data = await self.store.get_global_account_data_by_type_for_user(
+ user_id, account_data_type
+ )
- if event is None:
+ if account_data is None:
raise NotFoundError("Account data not found")
# If experimental support for MSC3391 is enabled, then this endpoint should
# return a 404 if the content for an account data type is an empty dict.
- if self._hs.config.experimental.msc3391_enabled and event == {}:
+ if self._hs.config.experimental.msc3391_enabled and account_data == {}:
raise NotFoundError("Account data not found")
- return 200, event
+ return 200, account_data
class UnstableAccountDataServlet(RestServlet):
@@ -108,6 +140,7 @@ class UnstableAccountDataServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
+ self._hs = hs
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
@@ -121,6 +154,10 @@ class UnstableAccountDataServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot delete account data for other users.")
+ # Raise an error if the account data type cannot be set directly.
+ if self._hs.config.experimental.msc4010_push_rules_account_data:
+ _check_can_set_account_data_type(account_data_type)
+
await self.handler.remove_account_data_for_user(user_id, account_data_type)
return 200, {}
@@ -164,9 +201,10 @@ class RoomAccountDataServlet(RestServlet):
Codes.INVALID_PARAM,
)
- body = parse_json_object_from_request(request)
-
- if account_data_type == "m.fully_read":
+ # Raise an error if the account data type cannot be set directly.
+ if self._hs.config.experimental.msc4010_push_rules_account_data:
+ _check_can_set_account_data_type(account_data_type)
+ elif account_data_type == ReceiptTypes.FULLY_READ:
raise SynapseError(
405,
"Cannot set m.fully_read through this API."
@@ -174,6 +212,8 @@ class RoomAccountDataServlet(RestServlet):
Codes.BAD_JSON,
)
+ body = parse_json_object_from_request(request)
+
# If experimental support for MSC3391 is enabled, then providing an empty dict
# as the value for an account data type should be functionally equivalent to
# calling the DELETE method on the same type.
@@ -208,19 +248,26 @@ class RoomAccountDataServlet(RestServlet):
Codes.INVALID_PARAM,
)
- event = await self.store.get_account_data_for_room_and_type(
- user_id, room_id, account_data_type
- )
+ # Room-specific push rules are not currently supported.
+ if (
+ self._hs.config.experimental.msc4010_push_rules_account_data
+ and account_data_type == AccountDataTypes.PUSH_RULES
+ ):
+ account_data: Optional[JsonDict] = {}
+ else:
+ account_data = await self.store.get_account_data_for_room_and_type(
+ user_id, room_id, account_data_type
+ )
- if event is None:
+ if account_data is None:
raise NotFoundError("Room account data not found")
# If experimental support for MSC3391 is enabled, then this endpoint should
# return a 404 if the content for an account data type is an empty dict.
- if self._hs.config.experimental.msc3391_enabled and event == {}:
+ if self._hs.config.experimental.msc3391_enabled and account_data == {}:
raise NotFoundError("Room account data not found")
- return 200, event
+ return 200, account_data
class UnstableRoomAccountDataServlet(RestServlet):
@@ -240,6 +287,7 @@ class UnstableRoomAccountDataServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
+ self._hs = hs
self.auth = hs.get_auth()
self.handler = hs.get_account_data_handler()
@@ -261,6 +309,10 @@ class UnstableRoomAccountDataServlet(RestServlet):
Codes.INVALID_PARAM,
)
+ # Raise an error if the account data type cannot be set directly.
+ if self._hs.config.experimental.msc4010_push_rules_account_data:
+ _check_can_set_account_data_type(account_data_type)
+
await self.handler.remove_account_data_for_room(
user_id, room_id, account_data_type
)
diff --git a/synapse/rest/client/appservice_ping.py b/synapse/rest/client/appservice_ping.py
index 31466a4ad4..3f553d14d1 100644
--- a/synapse/rest/client/appservice_ping.py
+++ b/synapse/rest/client/appservice_ping.py
@@ -39,9 +39,8 @@ logger = logging.getLogger(__name__)
class AppservicePingRestServlet(RestServlet):
PATTERNS = client_patterns(
- "/fi.mau.msc2659/appservice/(?P<appservice_id>[^/]*)/ping",
- unstable=True,
- releases=(),
+ "/appservice/(?P<appservice_id>[^/]*)/ping",
+ releases=("v1",),
)
def __init__(self, hs: "HomeServer"):
@@ -107,9 +106,8 @@ class AppservicePingRestServlet(RestServlet):
duration = time.monotonic() - start
- return HTTPStatus.OK, {"duration": int(duration * 1000)}
+ return HTTPStatus.OK, {"duration_ms": int(duration * 1000)}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
- if hs.config.experimental.msc2659_enabled:
- AppservicePingRestServlet(hs).register(http_server)
+ AppservicePingRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 11fc0b0678..a77b0697b7 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -33,6 +33,7 @@ class CapabilitiesRestServlet(RestServlet):
"""End point to expose the capabilities of the server."""
PATTERNS = client_patterns("/capabilities$")
+ CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index f17b4c8d22..570bb52747 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -39,12 +39,14 @@ logger = logging.getLogger(__name__)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ClientDirectoryServer(hs).register(http_server)
- ClientDirectoryListServer(hs).register(http_server)
- ClientAppserviceDirectoryListServer(hs).register(http_server)
+ if hs.config.worker.worker_app is None:
+ ClientDirectoryListServer(hs).register(http_server)
+ ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(RestServlet):
PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
+ CATEGORY = "Client API requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index ab7d8c9419..04561f36d7 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -94,7 +94,7 @@ class CreateFilterRestServlet(RestServlet):
set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit)
filter_id = await self.filtering.add_user_filter(
- user_localpart=target_user.localpart, user_filter=content
+ user_id=target_user, user_filter=content
)
return 200, {"filter_id": str(filter_id)}
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index 6209b79b01..9bbab5e624 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -15,7 +15,9 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+import re
+from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from synapse.api.errors import InvalidAPICallError, SynapseError
from synapse.http.server import HttpServer
@@ -288,7 +290,64 @@ class OneTimeKeyServlet(RestServlet):
await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+
+ # Generate a count for each algorithm, which is hard-coded to 1.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithm in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = {algorithm: 1}
+
+ result = await self.e2e_keys_handler.claim_one_time_keys(
+ query, timeout, always_include_fallback_keys=False
+ )
+ return 200, result
+
+
+class UnstableOneTimeKeyServlet(RestServlet):
+ """
+ Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
+ querying for multiple OTKs at once and always includes fallback keys in the
+ response.
+
+ POST /keys/claim HTTP/1.1
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": ["<algorithm>", ...]
+ } } }
+
+ HTTP/1.1 200 OK
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
+ """
+
+ PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
+ CATEGORY = "Encryption requests"
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.auth.get_user_by_req(request, allow_guest=True)
+ timeout = parse_integer(request, "timeout", 10 * 1000)
+ body = parse_json_object_from_request(request)
+
+ # Generate a count for each algorithm.
+ query: Dict[str, Dict[str, Dict[str, int]]] = {}
+ for user_id, one_time_keys in body.get("one_time_keys", {}).items():
+ for device_id, algorithms in one_time_keys.items():
+ query.setdefault(user_id, {})[device_id] = Counter(algorithms)
+
+ result = await self.e2e_keys_handler.claim_one_time_keys(
+ query, timeout, always_include_fallback_keys=True
+ )
return 200, result
@@ -394,6 +453,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
+ if hs.config.experimental.msc3983_appservice_otk_claims:
+ UnstableOneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index e04d4f2425..7d7714abb8 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -35,6 +35,7 @@ from synapse.api.errors import (
LoginError,
NotApprovedError,
SynapseError,
+ UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX
@@ -84,14 +85,10 @@ class LoginRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
+ self._main_store = hs.get_datastores().main
# JWT configuration variables.
self.jwt_enabled = hs.config.jwt.jwt_enabled
- self.jwt_secret = hs.config.jwt.jwt_secret
- self.jwt_subject_claim = hs.config.jwt.jwt_subject_claim
- self.jwt_algorithm = hs.config.jwt.jwt_algorithm
- self.jwt_issuer = hs.config.jwt.jwt_issuer
- self.jwt_audiences = hs.config.jwt.jwt_audiences
# SSO configuration.
self.saml2_enabled = hs.config.saml2.saml2_enabled
@@ -120,13 +117,13 @@ class LoginRestServlet(RestServlet):
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
- store=hs.get_datastores().main,
+ store=self._main_store,
clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
)
self._account_ratelimiter = Ratelimiter(
- store=hs.get_datastores().main,
+ store=self._main_store,
clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second,
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
@@ -292,6 +289,9 @@ class LoginRestServlet(RestServlet):
login_submission,
ratelimit=appservice.is_rate_limited(),
should_issue_refresh_token=should_issue_refresh_token,
+ # The user represented by an appservice's configured sender_localpart
+ # is not actually created in Synapse.
+ should_check_deactivated=qualified_user_id != appservice.sender,
)
async def _do_other_login(
@@ -338,6 +338,7 @@ class LoginRestServlet(RestServlet):
auth_provider_id: Optional[str] = None,
should_issue_refresh_token: bool = False,
auth_provider_session_id: Optional[str] = None,
+ should_check_deactivated: bool = True,
) -> LoginResponse:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -357,6 +358,11 @@ class LoginRestServlet(RestServlet):
should_issue_refresh_token: True if this login should issue
a refresh token alongside the access token.
auth_provider_session_id: The session ID got during login from the SSO IdP.
+ should_check_deactivated: True if the user should be checked for
+ deactivation status before logging in.
+
+ This exists purely for appservice's configured sender_localpart
+ which doesn't have an associated user in the database.
Returns:
Dictionary of account information after successful login.
@@ -376,6 +382,12 @@ class LoginRestServlet(RestServlet):
)
user_id = canonical_uid
+ # If the account has been deactivated, do not proceed with the login.
+ if should_check_deactivated:
+ deactivated = await self._main_store.get_user_deactivated_status(user_id)
+ if deactivated:
+ raise UserDeactivatedError("This account has been deactivated")
+
device_id = login_submission.get("device_id")
# If device_id is present, check that device_id is not longer than a reasonable 512 characters
@@ -434,7 +446,7 @@ class LoginRestServlet(RestServlet):
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
) -> LoginResponse:
"""
- Handle the final stage of SSO login.
+ Handle token login.
Args:
login_submission: The JSON request body.
@@ -459,72 +471,24 @@ class LoginRestServlet(RestServlet):
async def _do_jwt_login(
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
) -> LoginResponse:
- token = login_submission.get("token", None)
- if token is None:
- raise LoginError(
- 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
- )
-
- from authlib.jose import JsonWebToken, JWTClaims
- from authlib.jose.errors import BadSignatureError, InvalidClaimError, JoseError
-
- jwt = JsonWebToken([self.jwt_algorithm])
- claim_options = {}
- if self.jwt_issuer is not None:
- claim_options["iss"] = {"value": self.jwt_issuer, "essential": True}
- if self.jwt_audiences is not None:
- claim_options["aud"] = {"values": self.jwt_audiences, "essential": True}
-
- try:
- claims = jwt.decode(
- token,
- key=self.jwt_secret,
- claims_cls=JWTClaims,
- claims_options=claim_options,
- )
- except BadSignatureError:
- # We handle this case separately to provide a better error message
- raise LoginError(
- 403,
- "JWT validation failed: Signature verification failed",
- errcode=Codes.FORBIDDEN,
- )
- except JoseError as e:
- # A JWT error occurred, return some info back to the client.
- raise LoginError(
- 403,
- "JWT validation failed: %s" % (str(e),),
- errcode=Codes.FORBIDDEN,
- )
-
- try:
- claims.validate(leeway=120) # allows 2 min of clock skew
-
- # Enforce the old behavior which is rolled out in productive
- # servers: if the JWT contains an 'aud' claim but none is
- # configured, the login attempt will fail
- if claims.get("aud") is not None:
- if self.jwt_audiences is None or len(self.jwt_audiences) == 0:
- raise InvalidClaimError("aud")
- except JoseError as e:
- raise LoginError(
- 403,
- "JWT validation failed: %s" % (str(e),),
- errcode=Codes.FORBIDDEN,
- )
+ """
+ Handle the custom JWT login.
- user = claims.get(self.jwt_subject_claim, None)
- if user is None:
- raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
+ Args:
+ login_submission: The JSON request body.
+ should_issue_refresh_token: True if this login should issue
+ a refresh token alongside the access token.
- user_id = UserID(user, self.hs.hostname).to_string()
- result = await self._complete_login(
+ Returns:
+ The body of the JSON response.
+ """
+ user_id = self.hs.get_jwt_handler().validate_login(login_submission)
+ return await self._complete_login(
user_id,
login_submission,
create_non_existent_users=True,
should_issue_refresh_token=should_issue_refresh_token,
)
- return result
def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
@@ -677,9 +641,17 @@ class CasTicketServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
LoginRestServlet(hs).register(http_server)
- if hs.config.registration.refreshable_access_token_lifetime is not None:
+ if (
+ hs.config.worker.worker_app is None
+ and hs.config.registration.refreshable_access_token_lifetime is not None
+ ):
RefreshTokenServlet(hs).register(http_server)
- SsoRedirectServlet(hs).register(http_server)
+ if (
+ hs.config.cas.cas_enabled
+ or hs.config.saml2.saml2_enabled
+ or hs.config.oidc.oidc_enabled
+ ):
+ SsoRedirectServlet(hs).register(http_server)
if hs.config.cas.cas_enabled:
CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py
index 38ef4e459f..c99445da30 100644
--- a/synapse/rest/client/mutual_rooms.py
+++ b/synapse/rest/client/mutual_rooms.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_strings_from_args
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict
from ._base import client_patterns
@@ -30,11 +31,11 @@ logger = logging.getLogger(__name__)
class UserMutualRoomsServlet(RestServlet):
"""
- GET /uk.half-shot.msc2666/user/mutual_rooms/{user_id} HTTP/1.1
+ GET /uk.half-shot.msc2666/user/mutual_rooms?user_id={user_id} HTTP/1.1
"""
PATTERNS = client_patterns(
- "/uk.half-shot.msc2666/user/mutual_rooms/(?P<user_id>[^/]*)",
+ "/uk.half-shot.msc2666/user/mutual_rooms$",
releases=(), # This is an unstable feature
)
@@ -43,17 +44,35 @@ class UserMutualRoomsServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- async def on_GET(
- self, request: SynapseRequest, user_id: str
- ) -> Tuple[int, JsonDict]:
- UserID.from_string(user_id)
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
+ args: Dict[bytes, List[bytes]] = request.args # type: ignore
+
+ user_ids = parse_strings_from_args(args, "user_id", required=True)
+
+ if len(user_ids) > 1:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Duplicate user_id query parameter",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ # We don't do batching, so a batch token is illegal by default
+ if b"batch_token" in args:
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Unknown batch_token",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ user_id = user_ids[0]
requester = await self.auth.get_user_by_req(request)
if user_id == requester.user.to_string():
raise SynapseError(
- code=400,
- msg="You cannot request a list of shared rooms with yourself",
- errcode=Codes.FORBIDDEN,
+ HTTPStatus.UNPROCESSABLE_ENTITY,
+ "You cannot request a list of shared rooms with yourself",
+ errcode=Codes.INVALID_PARAM,
)
rooms = await self.store.get_mutual_rooms_between_users(
diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py
index 1147b6f8ec..5c9fece3ba 100644
--- a/synapse/rest/client/push_rule.py
+++ b/synapse/rest/client/push_rule.py
@@ -28,7 +28,6 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
@@ -146,14 +145,12 @@ class PushRuleRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester.user.to_string()
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
- rules_raw = await self.store.get_push_rules_for_user(user_id)
-
- rules = format_push_rules_for_user(requester.user, rules_raw)
+ rules = await self._push_rules_handler.push_rules_for_user(requester.user)
path_parts = path.split("/")[1:]
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index b8b296bc0c..785dfa08d8 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.storage.databases.main.relations import ThreadsNextBatch
@@ -49,6 +49,7 @@ class RelationPaginationServlet(RestServlet):
self.auth = hs.get_auth()
self._store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
+ self._support_recurse = hs.config.experimental.msc3981_recurse_relations
async def on_GET(
self,
@@ -63,6 +64,12 @@ class RelationPaginationServlet(RestServlet):
pagination_config = await PaginationConfig.from_request(
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
+ if self._support_recurse:
+ recurse = parse_boolean(
+ request, "org.matrix.msc3981.recurse", default=False
+ )
+ else:
+ recurse = False
# The unstable version of this API returns an extra field for client
# compatibility, see https://github.com/matrix-org/synapse/issues/12930.
@@ -75,6 +82,7 @@ class RelationPaginationServlet(RestServlet):
event_id=parent_id,
room_id=room_id,
pagin_config=pagination_config,
+ recurse=recurse,
include_original_event=include_original_event,
relation_type=relation_type,
event_type=event_type,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index c0705d4291..951bd033f5 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -501,7 +501,7 @@ class PublicRoomListRestServlet(RestServlet):
limit = None
handler = self.hs.get_room_list_handler()
- if server and server != self.hs.config.server.server_name:
+ if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
@@ -551,7 +551,7 @@ class PublicRoomListRestServlet(RestServlet):
limit = None
handler = self.hs.get_room_list_handler()
- if server and server != self.hs.config.server.server_name:
+ if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
@@ -1096,6 +1096,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
+ self._store = hs.get_datastores().main
self._relation_handler = hs.get_relations_handler()
self._msc3912_enabled = hs.config.experimental.msc3912_enabled
@@ -1113,6 +1114,19 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
+ # Ensure the redacts property in the content matches the one provided in
+ # the URL.
+ room_version = await self._store.get_room_version(room_id)
+ if room_version.msc2176_redaction_rules:
+ if "redacts" in content and content["redacts"] != event_id:
+ raise SynapseError(
+ 400,
+ "Cannot provide a redacts value incoherent with the event_id of the URL parameter",
+ Codes.INVALID_PARAM,
+ )
+ else:
+ content["redacts"] = event_id
+
try:
with_relations = None
if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content:
@@ -1128,20 +1142,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester, txn_id, room_id
)
+ # Event is not yet redacted, create a new event to redact it.
if event is None:
+ event_dict = {
+ "type": EventTypes.Redaction,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+ # Earlier room versions had a top-level redacts property.
+ if not room_version.msc2176_redaction_rules:
+ event_dict["redacts"] = event_id
+
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Redaction,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "redacts": event_id,
- },
- txn_id=txn_id,
+ requester, event_dict, txn_id=txn_id
)
if with_relations:
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index f2aaab6227..0d8a63d8be 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -50,6 +50,8 @@ class HttpTransactionCache:
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
+ self._msc3970_enabled = hs.config.experimental.msc3970_enabled
+
def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
@@ -58,6 +60,7 @@ class HttpTransactionCache:
requests to the same endpoint. The key is formed from the HTTP request
path and attributes from the requester: the access_token_id for regular users,
the user ID for guest users, and the appservice ID for appservice users.
+ With MSC3970, for regular users, the key is based on the user ID and device ID.
Args:
request: The incoming request.
@@ -67,11 +70,21 @@ class HttpTransactionCache:
"""
assert request.path is not None
path: str = request.path.decode("utf8")
+
if requester.is_guest:
assert requester.user is not None, "Guest requester must have a user ID set"
return (path, "guest", requester.user)
+
elif requester.app_service is not None:
return (path, "appservice", requester.app_service.id)
+
+ # With MSC3970, we use the user ID and device ID as the transaction key
+ elif self._msc3970_enabled:
+ assert requester.user, "Requester must have a user"
+ assert requester.device_id, "Requester must have a device_id"
+ return (path, "user", requester.user, requester.device_id)
+
+ # Otherwise, the pre-MSC3970 behaviour is to use the access token ID
else:
assert (
requester.access_token_id is not None
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index ecd84f435f..1eb11081a0 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -79,6 +79,7 @@ class VersionsRestServlet(RestServlet):
"v1.3",
"v1.4",
"v1.5",
+ "v1.6",
],
# as per MSC1497:
"unstable_features": {
@@ -90,7 +91,7 @@ class VersionsRestServlet(RestServlet):
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
# Implements additional endpoints as described in MSC2666
- "uk.half-shot.msc2666.mutual_rooms": True,
+ "uk.half-shot.msc2666.query_mutual_rooms": True,
# Whether new rooms will be set to encrypted or not (based on presets).
"io.element.e2ee_forced.public": self.e2ee_forced_public,
"io.element.e2ee_forced.private": self.e2ee_forced_private,
@@ -111,7 +112,7 @@ class VersionsRestServlet(RestServlet):
# Allows moderators to fetch redacted event content as described in MSC2815
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
# Adds a ping endpoint for appservices to check HS->AS connection
- "fi.mau.msc2659": self.config.experimental.msc2659_enabled,
+ "fi.mau.msc2659.stable": True, # TODO: remove when "v1.7" is added above
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
# Adds support for filtering /messages by event relation.
@@ -123,6 +124,10 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3912": self.config.experimental.msc3912_enabled,
# Adds support for unstable "intentional mentions" behaviour.
"org.matrix.msc3952_intentional_mentions": self.config.experimental.msc3952_intentional_mentions,
+ # Whether recursively provide relations is supported.
+ "org.matrix.msc3981": self.config.experimental.msc3981_recurse_relations,
+ # Adds support for deleting account data.
+ "org.matrix.msc3391": self.config.experimental.msc3391_enabled,
},
},
)
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index d03e728d42..22e7bf9d86 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -34,6 +34,8 @@ class LocalKey(RestServlet):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::
+ GET /_matrix/key/v2/server HTTP/1.1
+
GET /_matrix/key/v2/server/a.key.id HTTP/1.1
HTTP/1.1 200 OK
@@ -100,6 +102,15 @@ class LocalKey(RestServlet):
def on_GET(
self, request: Request, key_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
+ # Matrix 1.6 drops support for passing the key_id, this is incompatible
+ # with earlier versions and is allowed in order to support both.
+ # A warning is issued to help determine when it is safe to drop this.
+ if key_id:
+ logger.warning(
+ "Request for local server key with deprecated key ID (logging to determine usage level for future removal): %s",
+ key_id,
+ )
+
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 3bdb6ec909..8f3865d412 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -126,6 +126,15 @@ class RemoteKey(RestServlet):
self, request: Request, server: str, key_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
if server and key_id:
+ # Matrix 1.6 drops support for passing the key_id, this is incompatible
+ # with earlier versions and is allowed in order to support both.
+ # A warning is issued to help determine when it is safe to drop this.
+ logger.warning(
+ "Request for remote server key with deprecated key ID (logging to determine usage level for future removal): %s / %s",
+ server,
+ key_id,
+ )
+
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
@@ -155,13 +164,13 @@ class RemoteKey(RestServlet):
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
- cached = await self.store.get_server_keys_json(store_queries)
+ cached = await self.store.get_server_keys_json_for_remote(store_queries)
json_results: Set[bytes] = set()
time_now_ms = self.clock.time_msec()
- # Map server_name->key_id->int. Note that the value of the init is unused.
+ # Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
diff --git a/synapse/rest/media/download_resource.py b/synapse/rest/media/download_resource.py
index 8f270cf4cc..3c618ef60a 100644
--- a/synapse/rest/media/download_resource.py
+++ b/synapse/rest/media/download_resource.py
@@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
- self.server_name = hs.hostname
+ self._is_mine_server_name = hs.is_mine_server_name
async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
@@ -59,7 +59,7 @@ class DownloadResource(DirectServeJsonResource):
b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
- if server_name == self.server_name:
+ if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = parse_boolean(request, "allow_remote", default=True)
diff --git a/synapse/rest/media/thumbnail_resource.py b/synapse/rest/media/thumbnail_resource.py
index 4ee2a0dbda..661e604b85 100644
--- a/synapse/rest/media/thumbnail_resource.py
+++ b/synapse/rest/media/thumbnail_resource.py
@@ -59,7 +59,8 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
- self.server_name = hs.hostname
+ self._is_mine_server_name = hs.is_mine_server_name
+ self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
@@ -71,7 +72,7 @@ class ThumbnailResource(DirectServeJsonResource):
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png"
- if server_name == self.server_name:
+ if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type
@@ -82,6 +83,14 @@ class ThumbnailResource(DirectServeJsonResource):
)
self.media_repo.mark_recently_accessed(None, media_id)
else:
+ # Don't let users download media from configured domains, even if it
+ # is already downloaded. This is Trust & Safety tooling to make some
+ # media inaccessible to local users.
+ # See `prevent_media_downloads_from` config docs for more info.
+ if server_name in self.prevent_media_downloads_from:
+ respond_404(request)
+ return
+
if self.dynamic_thumbnails:
await self._select_or_generate_remote_thumbnail(
request, server_name, media_id, width, height, method, m_type
|