diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 57cac22252..55ddebb4fe 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -21,17 +21,16 @@ import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.rest.admin._base import (
- admin_patterns,
- assert_requester_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
DevicesRestServlet,
)
-from synapse.rest.admin.event_reports import EventReportsRestServlet
+from synapse.rest.admin.event_reports import (
+ EventReportDetailRestServlet,
+ EventReportsRestServlet,
+)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
@@ -44,19 +43,24 @@ from synapse.rest.admin.rooms import (
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
+from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
from synapse.rest.admin.users import (
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
+ PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
UserAdminServlet,
+ UserMediaRestServlet,
UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
UsersRestServlet,
UsersRestServletV2,
+ UserTokenRestServlet,
WhoisRestServlet,
)
+from synapse.types import RoomStreamToken
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
@@ -76,7 +80,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
@@ -109,7 +113,9 @@ class PurgeHistoryRestServlet(RestServlet):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- room_token = await self.store.get_topological_token_for_event(event_id)
+ room_token = RoomStreamToken(
+ event.depth, event.internal_metadata.stream_ordering
+ )
token = await room_token.to_string(self.store)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
@@ -159,9 +165,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
- "/purge_history_status/(?P<purge_id>[^/]+)"
- )
+ PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
def __init__(self, hs):
"""
@@ -212,13 +216,18 @@ def register_servlets(hs, http_server):
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
+ UserMediaRestServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
+ UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
+ UserMediaStatisticsRestServlet(hs).register(http_server)
+ EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
+ PushersRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index db9fea263a..e09234c644 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -22,28 +22,6 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
-def historical_admin_path_patterns(path_regex):
- """Returns the list of patterns for an admin endpoint, including historical ones
-
- This is a backwards-compatibility hack. Previously, the Admin API was exposed at
- various paths under /_matrix/client. This function returns a list of patterns
- matching those paths (as well as the new one), so that existing scripts which rely
- on the endpoints being available there are not broken.
-
- Note that this should only be used for existing endpoints: new ones should just
- register for the /_synapse/admin path.
- """
- return [
- re.compile(prefix + path_regex)
- for prefix in (
- "^/_synapse/admin/v1",
- "^/_matrix/client/api/v1/admin",
- "^/_matrix/client/unstable/admin",
- "^/_matrix/client/r0/admin",
- )
- ]
-
-
def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index a163863322..ffd3aa38f7 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -119,7 +119,7 @@ class DevicesRestServlet(RestServlet):
raise NotFoundError("Unknown user")
devices = await self.device_handler.get_devices_by_user(target_user.to_string())
- return 200, {"devices": devices}
+ return 200, {"devices": devices, "total": len(devices)}
class DeleteDevicesRestServlet(RestServlet):
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 5b8d0594cd..fd482f0e32 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -15,7 +15,7 @@
import logging
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
@@ -86,3 +86,47 @@ class EventReportsRestServlet(RestServlet):
ret["next_token"] = start + len(event_reports)
return 200, ret
+
+
+class EventReportDetailRestServlet(RestServlet):
+ """
+ Get a specific reported event that is known to the homeserver. Results are returned
+ in a dictionary containing report information.
+ The requester must have administrator access in Synapse.
+
+ GET /_synapse/admin/v1/event_reports/<report_id>
+ returns:
+ 200 OK with details report if success otherwise an error.
+
+ Args:
+ The parameter `report_id` is the ID of the event report in the database.
+ Returns:
+ JSON blob of information about the event report
+ """
+
+ PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request, report_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ message = (
+ "The report_id parameter must be a string representing a positive integer."
+ )
+ try:
+ report_id = int(report_id)
+ except ValueError:
+ raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+
+ if report_id < 0:
+ raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
+
+ ret = await self.store.get_event_report(report_id)
+ if not ret:
+ raise NotFoundError("Event report not found")
+
+ return 200, ret
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index 0b54ca09f4..d0c86b204a 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -16,10 +16,7 @@ import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from synapse.rest.admin._base import (
- assert_user_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
logger = logging.getLogger(__name__)
@@ -28,7 +25,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups
"""
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+ PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
self.group_server = hs.get_groups_server_handler()
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index ee75095c0e..c82b4f87d6 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -16,12 +16,12 @@
import logging
-from synapse.api.errors import AuthError
-from synapse.http.servlet import RestServlet, parse_integer
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.rest.admin._base import (
+ admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
logger = logging.getLogger(__name__)
@@ -33,10 +33,10 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = (
- historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+
# This path kept around for legacy reasons
- historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
@@ -62,9 +62,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
- PATTERNS = historical_admin_path_patterns(
- "/user/(?P<user_id>[^/]+)/media/quarantine"
- )
+ PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -89,7 +87,7 @@ class QuarantineMediaByID(RestServlet):
it via this server.
"""
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
@@ -115,7 +113,7 @@ class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
- PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
+ PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -133,7 +131,7 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/purge_media_cache")
+ PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()
@@ -150,6 +148,80 @@ class PurgeMediaCacheRestServlet(RestServlet):
return 200, ret
+class DeleteMediaByID(RestServlet):
+ """Delete local media by a given ID. Removes it from this server.
+ """
+
+ PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.server_name = hs.hostname
+ self.media_repository = hs.get_media_repository()
+
+ async def on_DELETE(self, request, server_name: str, media_id: str):
+ await assert_requester_is_admin(self.auth, request)
+
+ if self.server_name != server_name:
+ raise SynapseError(400, "Can only delete local media")
+
+ if await self.store.get_local_media(media_id) is None:
+ raise NotFoundError("Unknown media")
+
+ logging.info("Deleting local media by ID: %s", media_id)
+
+ deleted_media, total = await self.media_repository.delete_local_media(media_id)
+ return 200, {"deleted_media": deleted_media, "total": total}
+
+
+class DeleteMediaByDateSize(RestServlet):
+ """Delete local media and local copies of remote media by
+ timestamp and size.
+ """
+
+ PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.server_name = hs.hostname
+ self.media_repository = hs.get_media_repository()
+
+ async def on_POST(self, request, server_name: str):
+ await assert_requester_is_admin(self.auth, request)
+
+ before_ts = parse_integer(request, "before_ts", required=True)
+ size_gt = parse_integer(request, "size_gt", default=0)
+ keep_profiles = parse_boolean(request, "keep_profiles", default=True)
+
+ if before_ts < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter before_ts must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+ if size_gt < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter size_gt must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if self.server_name != server_name:
+ raise SynapseError(400, "Can only delete local media")
+
+ logging.info(
+ "Deleting local media by timestamp: %s, size larger than: %s, keep profile media: %s"
+ % (before_ts, size_gt, keep_profiles)
+ )
+
+ deleted_media, total = await self.media_repository.delete_old_local_media(
+ before_ts, size_gt, keep_profiles
+ )
+ return 200, {"deleted_media": deleted_media, "total": total}
+
+
def register_servlets_for_media_repo(hs, http_server):
"""
Media repo specific APIs.
@@ -159,3 +231,5 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
+ DeleteMediaByID(hs).register(http_server)
+ DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 09726d52d6..25f89e4685 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -29,7 +29,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -44,7 +43,7 @@ class ShutdownRoomRestServlet(RestServlet):
joined to the new room.
"""
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+ PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
def __init__(self, hs):
self.hs = hs
@@ -71,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
class DeleteRoomRestServlet(RestServlet):
- """Delete a room from server. It is a combination and improvement of
- shut down and purge room.
+ """Delete a room from server.
+
+ It is a combination and improvement of shutdown and purge room.
+
Shuts down a room by removing all local users from the room.
Blocking all future invites and joins to the room is optional.
+
If desired any local aliases will be repointed to a new room
- created by `new_room_user_id` and kicked users will be auto
+ created by `new_room_user_id` and kicked users will be auto-
joined to the new room.
- It will remove all trace of a room from the database.
+
+ If 'purge' is true, it will remove all traces of a room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
@@ -111,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
+ force_purge = content.get("force_purge", False)
+ if not isinstance(force_purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'force_purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -122,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
# Purge room
if purge:
- await self.pagination_handler.purge_room(room_id)
+ await self.pagination_handler.purge_room(room_id, force=force_purge)
return (200, ret)
@@ -138,7 +149,7 @@ class ListRoomRestServlet(RestServlet):
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
@@ -273,7 +284,7 @@ class JoinRoomAliasServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
async def on_POST(self, request, room_identifier):
@@ -309,7 +320,9 @@ class JoinRoomAliasServlet(RestServlet):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- fake_requester = create_requester(target_user)
+ fake_requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
new file mode 100644
index 0000000000..f2490e382d
--- /dev/null
+++ b/synapse/rest/admin/statistics.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.storage.databases.main.stats import UserSortOrder
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class UserMediaStatisticsRestServlet(RestServlet):
+ """
+ Get statistics about uploaded media by users.
+ """
+
+ PATTERNS = admin_patterns("/statistics/users/media$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ order_by = parse_string(
+ request, "order_by", default=UserSortOrder.USER_ID.value
+ )
+ if order_by not in (
+ UserSortOrder.MEDIA_LENGTH.value,
+ UserSortOrder.MEDIA_COUNT.value,
+ UserSortOrder.USER_ID.value,
+ UserSortOrder.DISPLAYNAME.value,
+ ):
+ raise SynapseError(
+ 400,
+ "Unknown value for order_by: %s" % (order_by,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ start = parse_integer(request, "from", default=0)
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ limit = parse_integer(request, "limit", default=100)
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ from_ts = parse_integer(request, "from_ts", default=0)
+ if from_ts < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from_ts must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ until_ts = parse_integer(request, "until_ts")
+ if until_ts is not None:
+ if until_ts < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter until_ts must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+ if until_ts <= from_ts:
+ raise SynapseError(
+ 400,
+ "Query parameter until_ts must be greater than from_ts.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ search_term = parse_string(request, "search_term")
+ if search_term == "":
+ raise SynapseError(
+ 400,
+ "Query parameter search_term cannot be an empty string.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ direction = parse_string(request, "dir", default="f")
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ users_media, total = await self.store.get_users_media_usage_paginate(
+ start, limit, from_ts, until_ts, order_by, direction, search_term
+ )
+ ret = {"users": users_media, "total": total}
+ if (start + limit) < total:
+ ret["next_token"] = start + len(users_media)
+
+ return 200, ret
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 20dc1d0e05..b0ff5e1ead 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,6 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -27,25 +28,40 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
-from synapse.types import UserID
+from synapse.rest.client.v2_alpha._base import client_patterns
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+_GET_PUSHERS_ALLOWED_KEYS = {
+ "app_display_name",
+ "app_id",
+ "data",
+ "device_display_name",
+ "kind",
+ "lang",
+ "profile_tag",
+ "pushkey",
+}
+
class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@@ -82,7 +98,7 @@ class UsersRestServletV2(RestServlet):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request):
await assert_requester_is_admin(self.auth, request)
@@ -135,7 +151,7 @@ class UserRestServletV2(RestServlet):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
@@ -322,7 +338,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
- PATTERNS = historical_admin_path_patterns("/register")
+ PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
@@ -399,6 +415,7 @@ class UserRegisterServlet(RestServlet):
admin = body.get("admin", None)
user_type = body.get("user_type", None)
+ displayname = body.get("displayname", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
@@ -435,6 +452,7 @@ class UserRegisterServlet(RestServlet):
password_hash=password_hash,
admin=bool(admin),
user_type=user_type,
+ default_display_name=displayname,
by_admin=True,
)
@@ -443,12 +461,19 @@ class UserRegisterServlet(RestServlet):
class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+ path_regex = "/whois/(?P<user_id>[^/]*)$"
+ PATTERNS = (
+ admin_patterns(path_regex)
+ +
+ # URL for spec reason
+ # https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid
+ client_patterns("/admin" + path_regex, v1=True)
+ )
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
+ self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@@ -461,13 +486,13 @@ class WhoisRestServlet(RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user")
- ret = await self.handlers.admin_handler.get_whois(target_user)
+ ret = await self.admin_handler.get_whois(target_user)
return 200, ret
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -498,7 +523,7 @@ class DeactivateAccountRestServlet(RestServlet):
class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+ PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs):
"""
@@ -541,9 +566,7 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = historical_admin_path_patterns(
- "/reset_password/(?P<target_user_id>[^/]*)"
- )
+ PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -585,13 +608,12 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
async def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to
@@ -612,7 +634,7 @@ class SearchUsersRestServlet(RestServlet):
term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
- ret = await self.handlers.store.search_users(term)
+ ret = await self.store.search_users(term)
return 200, ret
@@ -703,9 +725,163 @@ class UserMembershipRestServlet(RestServlet):
if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users")
+ user = await self.store.get_user_by_id(user_id)
+ if user is None:
+ raise NotFoundError("Unknown user")
+
room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
+ ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
+ return 200, ret
+
+
+class PushersRestServlet(RestServlet):
+ """
+ Gets information about all pushers for a specific `user_id`.
+
+ Example:
+ http://localhost:8008/_synapse/admin/v1/users/
+ @user:server/pushers
+
+ Returns:
+ pushers: Dictionary containing pushers information.
+ total: Number of pushers in dictonary `pushers`.
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
+
+ def __init__(self, hs):
+ self.is_mine = hs.is_mine
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.is_mine(UserID.from_string(user_id)):
+ raise SynapseError(400, "Can only lookup local users")
+
+ if not await self.store.get_user_by_id(user_id):
raise NotFoundError("User not found")
- ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
+ pushers = await self.store.get_pushers_by_user_id(user_id)
+
+ filtered_pushers = [
+ {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
+ for p in pushers
+ ]
+
+ return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
+
+
+class UserMediaRestServlet(RestServlet):
+ """
+ Gets information about all uploaded local media for a specific `user_id`.
+
+ Example:
+ http://localhost:8008/_synapse/admin/v1/users/
+ @user:server/media
+
+ Args:
+ The parameters `from` and `limit` are required for pagination.
+ By default, a `limit` of 100 is used.
+ Returns:
+ A list of media and an integer representing the total number of
+ media that exist given for this user
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
+
+ def __init__(self, hs):
+ self.is_mine = hs.is_mine
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.is_mine(UserID.from_string(user_id)):
+ raise SynapseError(400, "Can only lookup local users")
+
+ user = await self.store.get_user_by_id(user_id)
+ if user is None:
+ raise NotFoundError("Unknown user")
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ media, total = await self.store.get_local_media_by_user_paginate(
+ start, limit, user_id
+ )
+
+ ret = {"media": media, "total": total}
+ if (start + limit) < total:
+ ret["next_token"] = start + len(media)
+
return 200, ret
+
+
+class UserTokenRestServlet(RestServlet):
+ """An admin API for logging in as a user.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/login
+ {}
+
+ 200 OK
+ {
+ "access_token": "<some_token>"
+ }
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+
+ async def on_POST(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ auth_user = requester.user
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be logged in as")
+
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+
+ valid_until_ms = body.get("valid_until_ms")
+ if valid_until_ms and not isinstance(valid_until_ms, int):
+ raise SynapseError(400, "'valid_until_ms' parameter must be an int")
+
+ if auth_user.to_string() == user_id:
+ raise SynapseError(400, "Cannot use admin API to login as self")
+
+ token = await self.auth_handler.get_access_token_for_user_id(
+ user_id=auth_user.to_string(),
+ device_id=None,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=user_id,
+ )
+
+ return 200, {"access_token": token}
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index faabeeb91c..e5af26b176 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -42,14 +42,13 @@ class ClientDirectoryServer(RestServlet):
def __init__(self, hs):
super().__init__()
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
+ self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
- dir_handler = self.handlers.directory_handler
- res = await dir_handler.get_association(room_alias)
+ res = await self.directory_handler.get_association(room_alias)
return 200, res
@@ -79,19 +78,19 @@ class ClientDirectoryServer(RestServlet):
requester = await self.auth.get_user_by_req(request)
- await self.handlers.directory_handler.create_association(
+ await self.directory_handler.create_association(
requester, room_alias, room_id, servers
)
return 200, {}
async def on_DELETE(self, request, room_alias):
- dir_handler = self.handlers.directory_handler
-
try:
service = self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
- await dir_handler.delete_appservice_association(service, room_alias)
+ await self.directory_handler.delete_appservice_association(
+ service, room_alias
+ )
logger.info(
"Application service at %s deleted alias %s",
service.url,
@@ -107,7 +106,7 @@ class ClientDirectoryServer(RestServlet):
room_alias = RoomAlias.from_string(room_alias)
- await dir_handler.delete_association(requester, room_alias)
+ await self.directory_handler.delete_association(requester, room_alias)
logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
@@ -122,7 +121,7 @@ class ClientDirectoryListServer(RestServlet):
def __init__(self, hs):
super().__init__()
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
+ self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
@@ -138,7 +137,7 @@ class ClientDirectoryListServer(RestServlet):
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- await self.handlers.directory_handler.edit_published_room_list(
+ await self.directory_handler.edit_published_room_list(
requester, room_id, visibility
)
@@ -147,7 +146,7 @@ class ClientDirectoryListServer(RestServlet):
async def on_DELETE(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
- await self.handlers.directory_handler.edit_published_room_list(
+ await self.directory_handler.edit_published_room_list(
requester, room_id, "private"
)
@@ -162,7 +161,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
def __init__(self, hs):
super().__init__()
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
+ self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id):
@@ -180,7 +179,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
403, "Only appservices can edit the appservice published room list"
)
- await self.handlers.directory_handler.edit_published_appservice_room_list(
+ await self.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility
)
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 1ecb77aa26..6de4078290 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -67,9 +67,6 @@ class EventStreamRestServlet(RestServlet):
return 200, chunk
- def on_OPTIONS(self, request):
- return 200, {}
-
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 3d1693d7ac..d7ae148214 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.handlers.auth import (
- convert_client_dict_legacy_fields_to_identifier,
- login_id_phone_to_thirdparty,
-)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
-from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -67,7 +62,6 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- self.handlers = hs.get_handlers()
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -79,11 +73,6 @@ class LoginRestServlet(RestServlet):
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
- self._failed_attempts_ratelimiter = Ratelimiter(
- clock=hs.get_clock(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- )
def on_GET(self, request: SynapseRequest):
flows = []
@@ -111,10 +100,9 @@ class LoginRestServlet(RestServlet):
({"type": t} for t in self.auth_handler.get_supported_login_types())
)
- return 200, {"flows": flows}
+ flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
- def on_OPTIONS(self, request: SynapseRequest):
- return 200, {}
+ return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
@@ -142,27 +130,31 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- def _get_qualified_user_id(self, identifier):
- if identifier["type"] != "m.id.user":
- raise SynapseError(400, "Unknown login identifier type")
- if "user" not in identifier:
- raise SynapseError(400, "User identifier is missing 'user' key")
-
- if identifier["user"].startswith("@"):
- return identifier["user"]
- else:
- return UserID(identifier["user"], self.hs.hostname).to_string()
-
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
):
- logger.info(
- "Got appservice login request with identifier: %r",
- login_submission.get("identifier"),
- )
+ identifier = login_submission.get("identifier")
+ logger.info("Got appservice login request with identifier: %r", identifier)
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
- qualified_user_id = self._get_qualified_user_id(identifier)
+ if not isinstance(identifier, dict):
+ raise SynapseError(
+ 400, "Invalid identifier in login submission", Codes.INVALID_PARAM
+ )
+
+ # this login flow only supports identifiers of type "m.id.user".
+ if identifier.get("type") != "m.id.user":
+ raise SynapseError(
+ 400, "Unknown login identifier type", Codes.INVALID_PARAM
+ )
+
+ user = identifier.get("user")
+ if not isinstance(user, str):
+ raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
+
+ if user.startswith("@"):
+ qualified_user_id = user
+ else:
+ qualified_user_id = UserID(user, self.hs.hostname).to_string()
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
@@ -188,91 +180,9 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
-
- # convert phone type identifiers to generic threepids
- if identifier["type"] == "m.id.phone":
- identifier = login_id_phone_to_thirdparty(identifier)
-
- # convert threepid identifiers to user IDs
- if identifier["type"] == "m.id.thirdparty":
- address = identifier.get("address")
- medium = identifier.get("medium")
-
- if medium is None or address is None:
- raise SynapseError(400, "Invalid thirdparty identifier")
-
- # For emails, canonicalise the address.
- # We store all email addresses canonicalised in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- if medium == "email":
- try:
- address = canonicalise_email(address)
- except ValueError as e:
- raise SynapseError(400, str(e))
-
- # We also apply account rate limiting using the 3PID as a key, as
- # otherwise using 3PID bypasses the ratelimiting based on user ID.
- self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
-
- # Check for login providers that support 3pid login types
- (
- canonical_user_id,
- callback_3pid,
- ) = await self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
- if canonical_user_id:
- # Authentication through password provider and 3pid succeeded
-
- result = await self._complete_login(
- canonical_user_id, login_submission, callback_3pid
- )
- return result
-
- # No password providers were able to handle this 3pid
- # Check local store
- user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- medium, address
- )
- if not user_id:
- logger.warning(
- "unknown 3pid identifier medium %s, address %r", medium, address
- )
- # We mark that we've failed to log in here, as
- # `check_password_provider_3pid` might have returned `None` due
- # to an incorrect password, rather than the account not
- # existing.
- #
- # If it returned None but the 3PID was bound then we won't hit
- # this code path, which is fine as then the per-user ratelimit
- # will kick in below.
- self._failed_attempts_ratelimiter.can_do_action((medium, address))
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- identifier = {"type": "m.id.user", "user": user_id}
-
- # by this point, the identifier should be an m.id.user: if it's anything
- # else, we haven't understood it.
- qualified_user_id = self._get_qualified_user_id(identifier)
-
- # Check if we've hit the failed ratelimit (but don't update it)
- self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), update=False
+ canonical_user_id, callback = await self.auth_handler.validate_login(
+ login_submission, ratelimit=True
)
-
- try:
- canonical_user_id, callback = await self.auth_handler.validate_login(
- identifier["user"], login_submission
- )
- except LoginError:
- # The user has failed to log in, so we need to update the rate
- # limiter. Using `can_do_action` avoids us raising a ratelimit
- # exception and masking the LoginError. The actual ratelimiting
- # should have happened above.
- self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
- raise
-
result = await self._complete_login(
canonical_user_id, login_submission, callback
)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index f792b50cdc..ad8cea49c6 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -30,9 +30,6 @@ class LogoutRestServlet(RestServlet):
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- def on_OPTIONS(self, request):
- return 200, {}
-
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_expired=True)
@@ -58,9 +55,6 @@ class LogoutAllRestServlet(RestServlet):
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
- def on_OPTIONS(self, request):
- return 200, {}
-
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 79d8e3057f..23a529f8e3 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -86,9 +86,6 @@ class PresenceStatusRestServlet(RestServlet):
return 200, {}
- def on_OPTIONS(self, request):
- return 200, {}
-
def register_servlets(hs, http_server):
PresenceStatusRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index b686cd671f..85a66458c5 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -59,15 +59,14 @@ class ProfileDisplaynameRestServlet(RestServlet):
try:
new_name = content["displayname"]
except Exception:
- return 400, "Unable to parse name"
+ raise SynapseError(
+ code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
+ )
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
return 200, {}
- def on_OPTIONS(self, request, user_id):
- return 200, {}
-
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -116,9 +115,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, {}
- def on_OPTIONS(self, request, user_id):
- return 200, {}
-
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index f9eecb7cf5..241e535917 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -155,9 +155,6 @@ class PushRuleRestServlet(RestServlet):
else:
raise UnrecognizedRequestError()
- def on_OPTIONS(self, request, path):
- return 200, {}
-
def notify_user(self, user_id):
stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 28dabf1c7a..8fe83f321a 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -60,9 +60,6 @@ class PushersRestServlet(RestServlet):
return 200, {"pushers": filtered_pushers}
- def on_OPTIONS(self, _):
- return 200, {}
-
class PushersSetRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers/set$", v1=True)
@@ -140,9 +137,6 @@ class PushersSetRestServlet(RestServlet):
return 200, {}
- def on_OPTIONS(self, _):
- return 200, {}
-
class PushersRemoveRestServlet(RestServlet):
"""
@@ -182,9 +176,6 @@ class PushersRemoveRestServlet(RestServlet):
)
return None
- def on_OPTIONS(self, _):
- return 200, {}
-
def register_servlets(hs, http_server):
PushersRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index b63389e5fe..93c06afe27 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -18,7 +18,7 @@
import logging
import re
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -48,8 +48,7 @@ from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID,
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
@@ -72,20 +71,6 @@ class RoomCreateRestServlet(TransactionRestServlet):
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
- # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
- http_server.register_paths(
- "OPTIONS",
- client_patterns("/rooms(?:/.*)?$", v1=True),
- self.on_OPTIONS,
- self.__class__.__name__,
- )
- # define CORS for /createRoom[/txnid]
- http_server.register_paths(
- "OPTIONS",
- client_patterns("/createRoom(?:/.*)?$", v1=True),
- self.on_OPTIONS,
- self.__class__.__name__,
- )
def on_PUT(self, request, txn_id):
set_tag("txn_id", txn_id)
@@ -104,15 +89,11 @@ class RoomCreateRestServlet(TransactionRestServlet):
user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config
- def on_OPTIONS(self, request):
- return 200, {}
-
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super().__init__(hs)
- self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@@ -798,7 +779,6 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super().__init__(hs)
- self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
@@ -903,7 +883,7 @@ class RoomAliasListServlet(RestServlet):
def __init__(self, hs: "synapse.server.HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- self.directory_handler = hs.get_handlers().directory_handler
+ self.directory_handler = hs.get_directory_handler()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request)
@@ -920,7 +900,7 @@ class SearchRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
- self.handlers = hs.get_handlers()
+ self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth()
async def on_POST(self, request):
@@ -929,9 +909,7 @@ class SearchRestServlet(RestServlet):
content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch")
- results = await self.handlers.search_handler.search(
- requester.user, content, batch
- )
+ results = await self.search_handler.search(requester.user, content, batch)
return 200, results
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index b8d491ca5c..d07ca2c47c 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -69,9 +69,6 @@ class VoipRestServlet(RestServlet):
},
)
- def on_OPTIONS(self, request):
- return 200, {}
-
def register_servlets(hs, http_server):
VoipRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 86d3d86fad..e0feebea94 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -38,6 +38,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -56,7 +57,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.hs = hs
self.datastore = hs.get_datastore()
self.config = hs.config
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self.mailer = Mailer(
@@ -114,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -143,6 +144,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
+ threepid_send_requests.labels(type="email", reason="password_reset").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -268,9 +273,6 @@ class PasswordRestServlet(RestServlet):
return 200, {}
- def on_OPTIONS(self, _):
- return 200, {}
-
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -327,7 +329,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.config = hs.config
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.store = self.hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
@@ -385,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -414,6 +416,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
+ threepid_send_requests.labels(type="email", reason="add_threepid").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -424,7 +430,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
body = parse_json_object_from_request(request)
@@ -460,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@@ -484,6 +490,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
next_link,
)
+ threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -574,7 +584,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
if not self.config.account_threepid_delegate_msisdn:
@@ -604,7 +614,7 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
@@ -660,7 +670,7 @@ class ThreepidAddRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@@ -711,7 +721,7 @@ class ThreepidBindRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
async def on_POST(self, request):
@@ -740,7 +750,7 @@ class ThreepidUnbindRestServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 5fbfae5991..fab077747f 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -176,9 +176,6 @@ class AuthRestServlet(RestServlet):
respond_with_html(request, 200, html)
return None
- def on_OPTIONS(self, _):
- return 200, {}
-
def register_servlets(hs, http_server):
AuthRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 7e174de692..af117cb27c 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +22,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from ._base import client_patterns, interactive_auth_handler
@@ -151,7 +153,139 @@ class DeviceRestServlet(RestServlet):
return 200, {}
+class DehydratedDeviceServlet(RestServlet):
+ """Retrieve or store a dehydrated device.
+
+ GET /org.matrix.msc2697.v2/dehydrated_device
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id",
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device"
+ }
+ }
+
+ PUT /org.matrix.msc2697/dehydrated_device
+ Content-Type: application/json
+
+ {
+ "device_data": {
+ "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+ "account": "dehydrated_device"
+ }
+ }
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id"
+ }
+
+ """
+
+ PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device", releases=())
+
+ def __init__(self, hs):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ async def on_GET(self, request: SynapseRequest):
+ requester = await self.auth.get_user_by_req(request)
+ dehydrated_device = await self.device_handler.get_dehydrated_device(
+ requester.user.to_string()
+ )
+ if dehydrated_device is not None:
+ (device_id, device_data) = dehydrated_device
+ result = {"device_id": device_id, "device_data": device_data}
+ return (200, result)
+ else:
+ raise errors.NotFoundError("No dehydrated device available")
+
+ async def on_PUT(self, request: SynapseRequest):
+ submission = parse_json_object_from_request(request)
+ requester = await self.auth.get_user_by_req(request)
+
+ if "device_data" not in submission:
+ raise errors.SynapseError(
+ 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
+ )
+ elif not isinstance(submission["device_data"], dict):
+ raise errors.SynapseError(
+ 400,
+ "device_data must be an object",
+ errcode=errors.Codes.INVALID_PARAM,
+ )
+
+ device_id = await self.device_handler.store_dehydrated_device(
+ requester.user.to_string(),
+ submission["device_data"],
+ submission.get("initial_device_display_name", None),
+ )
+ return 200, {"device_id": device_id}
+
+
+class ClaimDehydratedDeviceServlet(RestServlet):
+ """Claim a dehydrated device.
+
+ POST /org.matrix.msc2697.v2/dehydrated_device/claim
+ Content-Type: application/json
+
+ {
+ "device_id": "dehydrated_device_id"
+ }
+
+ HTTP/1.1 200 OK
+ Content-Type: application/json
+
+ {
+ "success": true,
+ }
+
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc2697.v2/dehydrated_device/claim", releases=()
+ )
+
+ def __init__(self, hs):
+ super().__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.device_handler = hs.get_device_handler()
+
+ async def on_POST(self, request: SynapseRequest):
+ requester = await self.auth.get_user_by_req(request)
+
+ submission = parse_json_object_from_request(request)
+
+ if "device_id" not in submission:
+ raise errors.SynapseError(
+ 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
+ )
+ elif not isinstance(submission["device_id"], str):
+ raise errors.SynapseError(
+ 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
+ )
+
+ result = await self.device_handler.rehydrate_device(
+ requester.user.to_string(),
+ self.auth.get_access_token_from_request(request),
+ submission["device_id"],
+ )
+
+ return (200, result)
+
+
def register_servlets(hs, http_server):
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
+ DehydratedDeviceServlet(hs).register(http_server)
+ ClaimDehydratedDeviceServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 55c4606569..b91996c738 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -67,6 +68,7 @@ class KeyUploadServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
async def on_POST(self, request, device_id):
@@ -75,23 +77,28 @@ class KeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
if device_id is not None:
- # passing the device_id here is deprecated; however, we allow it
- # for now for compatibility with older clients.
+ # Providing the device_id should only be done for setting keys
+ # for dehydrated devices; however, we allow it for any device for
+ # compatibility with older clients.
if requester.device_id is not None and device_id != requester.device_id:
- set_tag("error", True)
- log_kv(
- {
- "message": "Client uploading keys for a different device",
- "logged_in_id": requester.device_id,
- "key_being_uploaded": device_id,
- }
- )
- logger.warning(
- "Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id,
- device_id,
+ dehydrated_device = await self.device_handler.get_dehydrated_device(
+ user_id
)
+ if dehydrated_device is not None and device_id != dehydrated_device[0]:
+ set_tag("error", True)
+ log_kv(
+ {
+ "message": "Client uploading keys for a different device",
+ "logged_in_id": requester.device_id,
+ "key_being_uploaded": device_id,
+ }
+ )
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
else:
device_id = requester.device_id
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ec8ef9bf88..5374d2c1b6 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -78,7 +79,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
"""
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.config = hs.config
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
@@ -134,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -163,6 +164,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
+ threepid_send_requests.labels(type="email", reason="register").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -176,7 +181,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
"""
super().__init__()
self.hs = hs
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
body = parse_json_object_from_request(request)
@@ -209,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
@@ -234,6 +239,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
next_link,
)
+ threepid_send_requests.labels(type="msisdn", reason="register").observe(
+ send_attempt
+ )
+
return 200, ret
@@ -370,7 +379,7 @@ class RegisterRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
- self.identity_handler = hs.get_handlers().identity_handler
+ self.identity_handler = hs.get_identity_handler()
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
@@ -644,9 +653,6 @@ class RegisterRestServlet(RestServlet):
return 200, return_dict
- def on_OPTIONS(self, _):
- return 200, {}
-
async def _do_appservice_registration(self, username, as_token, body):
user_id = await self.registration_handler.appservice_register(
username, as_token
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 6779df952f..8e52e4cca4 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
)
with context:
sync_result = await self.sync_handler.wait_for_sync_for_user(
+ requester,
sync_config,
since_token=since_token,
timeout=timeout,
@@ -236,6 +237,7 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
+ "org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
"next_batch": await sync_result.next_batch.to_string(self.store),
}
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index c16280f668..d8e8e48c1c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs):
self.config = hs.config
- self.clock = hs.clock
+ self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 6568e61829..67aa993f19 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -213,6 +213,12 @@ async def respond_with_responder(
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
"""
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
+
if not responder:
respond_404(request)
return
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 7447eeaebe..9e079f672f 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -69,6 +69,23 @@ class MediaFilePaths:
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
+ def local_media_thumbnail_dir(self, media_id: str) -> str:
+ """
+ Retrieve the local store path of thumbnails of a given media_id
+
+ Args:
+ media_id: The media ID to query.
+ Returns:
+ Path of local_thumbnails from media_id
+ """
+ return os.path.join(
+ self.base_path,
+ "local_thumbnails",
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
+ )
+
def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index e1192b47cd..9cac74ebd8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,7 +18,7 @@ import errno
import logging
import os
import shutil
-from typing import IO, Dict, Optional, Tuple
+from typing import IO, Dict, List, Optional, Tuple
import twisted.internet.error
import twisted.web.http
@@ -305,15 +305,12 @@ class MediaRepository:
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
# one.
- if media_info:
- file_id = media_info["filesystem_id"]
- else:
- file_id = random_string(24)
-
- file_info = FileInfo(server_name, file_id)
# If we have an entry in the DB, try and look for it
if media_info:
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
@@ -324,14 +321,34 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it.
- media_info = await self._download_remote_file(server_name, media_id, file_id)
+ try:
+ media_info = await self._download_remote_file(server_name, media_id,)
+ except SynapseError:
+ raise
+ except Exception as e:
+ # An exception may be because we downloaded media in another
+ # process, so let's check if we magically have the media.
+ media_info = await self.store.get_cached_remote_media(server_name, media_id)
+ if not media_info:
+ raise e
+
+ file_id = media_info["filesystem_id"]
+ file_info = FileInfo(server_name, file_id)
+
+ # We generate thumbnails even if another process downloaded the media
+ # as a) it's conceivable that the other download request dies before it
+ # generates thumbnails, but mainly b) we want to be sure the thumbnails
+ # have finished being generated before responding to the client,
+ # otherwise they'll request thumbnails and get a 404 if they're not
+ # ready yet.
+ await self._generate_thumbnails(
+ server_name, media_id, file_id, media_info["media_type"]
+ )
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(
- self, server_name: str, media_id: str, file_id: str
- ) -> dict:
+ async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -346,6 +363,8 @@ class MediaRepository:
The media info of the file.
"""
+ file_id = random_string(24)
+
file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
@@ -401,22 +420,32 @@ class MediaRepository:
await finish()
- media_type = headers[b"Content-Type"][0].decode("ascii")
- upload_name = get_filename_from_headers(headers)
- time_now_ms = self.clock.time_msec()
+ media_type = headers[b"Content-Type"][0].decode("ascii")
+ upload_name = get_filename_from_headers(headers)
+ time_now_ms = self.clock.time_msec()
+
+ # Multiple remote media download requests can race (when using
+ # multiple media repos), so this may throw a violation constraint
+ # exception. If it does we'll delete the newly downloaded file from
+ # disk (as we're in the ctx manager).
+ #
+ # However: we've already called `finish()` so we may have also
+ # written to the storage providers. This is preferable to the
+ # alternative where we call `finish()` *after* this, where we could
+ # end up having an entry in the DB but fail to write the files to
+ # the storage providers.
+ await self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
logger.info("Stored remote media in file %r", fname)
- await self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
-
media_info = {
"media_type": media_type,
"media_length": length,
@@ -425,8 +454,6 @@ class MediaRepository:
"filesystem_id": file_id,
}
- await self._generate_thumbnails(server_name, media_id, file_id, media_type)
-
return media_info
def _get_thumbnail_requirements(self, media_type):
@@ -692,42 +719,60 @@ class MediaRepository:
if not t_byte_source:
continue
- try:
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- thumbnail=True,
- thumbnail_width=t_width,
- thumbnail_height=t_height,
- thumbnail_method=t_method,
- thumbnail_type=t_type,
- url_cache=url_cache,
- )
-
- output_path = await self.media_storage.store_file(
- t_byte_source, file_info
- )
- finally:
- t_byte_source.close()
-
- t_len = os.path.getsize(output_path)
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
- # Write to database
- if server_name:
- await self.store.store_remote_media_thumbnail(
- server_name,
- media_id,
- file_id,
- t_width,
- t_height,
- t_type,
- t_method,
- t_len,
- )
- else:
- await self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
+ await self.media_storage.write_to_file(t_byte_source, f)
+ await finish()
+ finally:
+ t_byte_source.close()
+
+ t_len = os.path.getsize(fname)
+
+ # Write to database
+ if server_name:
+ # Multiple remote media download requests can race (when
+ # using multiple media repos), so this may throw a violation
+ # constraint exception. If it does we'll delete the newly
+ # generated thumbnail from disk (as we're in the ctx
+ # manager).
+ #
+ # However: we've already called `finish()` so we may have
+ # also written to the storage providers. This is preferable
+ # to the alternative where we call `finish()` *after* this,
+ # where we could end up having an entry in the DB but fail
+ # to write the files to the storage providers.
+ try:
+ await self.store.store_remote_media_thumbnail(
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
+ )
+ except Exception as e:
+ thumbnail_exists = await self.store.get_remote_media_thumbnail(
+ server_name, media_id, t_width, t_height, t_type,
+ )
+ if not thumbnail_exists:
+ raise e
+ else:
+ await self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
return {"width": m_width, "height": m_height}
@@ -767,6 +812,76 @@ class MediaRepository:
return {"deleted": deleted}
+ async def delete_local_media(self, media_id: str) -> Tuple[List[str], int]:
+ """
+ Delete the given local or remote media ID from this server
+
+ Args:
+ media_id: The media ID to delete.
+ Returns:
+ A tuple of (list of deleted media IDs, total deleted media IDs).
+ """
+ return await self._remove_local_media_from_disk([media_id])
+
+ async def delete_old_local_media(
+ self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True,
+ ) -> Tuple[List[str], int]:
+ """
+ Delete local or remote media from this server by size and timestamp. Removes
+ media files, any thumbnails and cached URLs.
+
+ Args:
+ before_ts: Unix timestamp in ms.
+ Files that were last used before this timestamp will be deleted
+ size_gt: Size of the media in bytes. Files that are larger will be deleted
+ keep_profiles: Switch to delete also files that are still used in image data
+ (e.g user profile, room avatar)
+ If false these files will be deleted
+ Returns:
+ A tuple of (list of deleted media IDs, total deleted media IDs).
+ """
+ old_media = await self.store.get_local_media_before(
+ before_ts, size_gt, keep_profiles,
+ )
+ return await self._remove_local_media_from_disk(old_media)
+
+ async def _remove_local_media_from_disk(
+ self, media_ids: List[str]
+ ) -> Tuple[List[str], int]:
+ """
+ Delete local or remote media from this server. Removes media files,
+ any thumbnails and cached URLs.
+
+ Args:
+ media_ids: List of media_id to delete
+ Returns:
+ A tuple of (list of deleted media IDs, total deleted media IDs).
+ """
+ removed_media = []
+ for media_id in media_ids:
+ logger.info("Deleting media with ID '%s'", media_id)
+ full_path = self.filepaths.local_media_filepath(media_id)
+ try:
+ os.remove(full_path)
+ except OSError as e:
+ logger.warning("Failed to remove file: %r: %s", full_path, e)
+ if e.errno == errno.ENOENT:
+ pass
+ else:
+ continue
+
+ thumbnail_dir = self.filepaths.local_media_thumbnail_dir(media_id)
+ shutil.rmtree(thumbnail_dir, ignore_errors=True)
+
+ await self.store.delete_remote_media(self.server_name, media_id)
+
+ await self.store.delete_url_cache((media_id,))
+ await self.store.delete_url_cache_media((media_id,))
+
+ removed_media.append(media_id)
+
+ return removed_media, len(removed_media)
+
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a9586fb0b7..268e0c8f50 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -52,6 +52,7 @@ class MediaStorage:
storage_providers: Sequence["StorageProviderWrapper"],
):
self.hs = hs
+ self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@@ -70,13 +71,16 @@ class MediaStorage:
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- await defer_to_thread(
- self.hs.get_reactor(), _write_file_synchronously, source, f
- )
+ await self.write_to_file(source, f)
await finish_cb()
return fname
+ async def write_to_file(self, source: IO, output: IO):
+ """Asynchronously write the `source` to `output`.
+ """
+ await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
+
@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
@@ -112,14 +116,20 @@ class MediaStorage:
finished_called = [False]
- async def finish():
- for provider in self.storage_providers:
- await provider.store_file(path, file_info)
-
- finished_called[0] = True
-
try:
with open(fname, "wb") as f:
+
+ async def finish():
+ # Ensure that all writes have been flushed and close the
+ # file.
+ f.flush()
+ f.close()
+
+ for provider in self.storage_providers:
+ await provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
yield f, fname, finish
except Exception:
try:
@@ -210,7 +220,7 @@ class MediaStorage:
if res:
with res:
consumer = BackgroundFileConsumer(
- open(local_path, "wb"), self.hs.get_reactor()
+ open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
|