diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index c499afd4be..465e06772b 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -69,6 +69,7 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
from synapse.rest.admin.username_available import UsernameAvailableRestServlet
from synapse.rest.admin.users import (
+ AccountDataRestServlet,
AccountValidityRenewServlet,
DeactivateAccountRestServlet,
PushersRestServlet,
@@ -108,7 +109,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = admin_patterns(
- "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
+ "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
)
def __init__(self, hs: "HomeServer"):
@@ -195,7 +196,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
- PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
+ PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.pagination_handler = hs.get_pagination_handler()
@@ -255,6 +256,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
+ AccountDataRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
ShadowBanRestServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/background_updates.py b/synapse/rest/admin/background_updates.py
index 479672d4d5..6ec00ce0b9 100644
--- a/synapse/rest/admin/background_updates.py
+++ b/synapse/rest/admin/background_updates.py
@@ -22,7 +22,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
-from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
return HTTPStatus.OK, {"enabled": enabled}
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
@@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
self._data_stores = hs.get_datastores()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
# We need to check that all configured databases have updates enabled.
# (They *should* all be in sync.)
@@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
class BackgroundUpdateStartJobRestServlet(RestServlet):
"""Allows to start specific background updates"""
- PATTERNS = admin_patterns("/background_updates/start_job")
+ PATTERNS = admin_patterns("/background_updates/start_job$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastore()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_requester_is_admin(self._auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["job_name"])
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 2e5a6600d3..d9905ff560 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
- self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str, device_id: str
@@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet):
device = await self.device_handler.get_device(
target_user.to_string(), device_id
)
+ if device is None:
+ raise NotFoundError("No device found")
return HTTPStatus.OK, device
async def on_DELETE(
@@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs: "HomeServer"):
- """
- Args:
- hs: server
- """
- self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
@@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
+ self.is_mine = hs.is_mine
async def on_POST(
self, request: SynapseRequest, user_id: str
@@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
u = await self.store.get_user_by_id(target_user.to_string())
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 5ee8b11110..38477f8ead 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 744687be35..50d88c9109 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
200 OK with details of a destination if success otherwise an error.
"""
- PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
+ PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index a27110388f..cd697e180e 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups"""
- PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
+ PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.group_server = hs.get_groups_server_handler()
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 9e23e2d8fc..7236e4027f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,7 +17,7 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
@@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = [
- *admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"),
+ *admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
# This path kept around for legacy reasons
- *admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
+ *admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
]
def __init__(self, hs: "HomeServer"):
@@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
- PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$")
+ PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
- "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+ "/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
"""
PATTERNS = admin_patterns(
- "/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+ "/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
def __init__(self, hs: "HomeServer"):
@@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
logging.info(
"Remove from quarantine local media by ID: %s/%s", server_name, media_id
@@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined."""
- PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+ PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
logging.info("Protecting local media by ID: %s", media_id)
@@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
class UnprotectMediaByID(RestServlet):
"""Unprotect local media from being quarantined."""
- PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
+ PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
logging.info("Unprotecting local media by ID: %s", media_id)
@@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room."""
- PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$")
+ PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- is_admin = await self.auth.is_server_admin(requester.user)
- if not is_admin:
- raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
+ await assert_requester_is_admin(self.auth, request)
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
@@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
class DeleteMediaByID(RestServlet):
"""Delete local media by a given ID. Removes it from this server."""
- PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
+ PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
timestamp and size.
"""
- PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$")
+ PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
media that exist given for this user
"""
- PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
@@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
- allowed_values=(
- MediaSortOrder.MEDIA_ID.value,
- MediaSortOrder.UPLOAD_NAME.value,
- MediaSortOrder.CREATED_TS.value,
- MediaSortOrder.LAST_ACCESS_TS.value,
- MediaSortOrder.MEDIA_LENGTH.value,
- MediaSortOrder.MEDIA_TYPE.value,
- MediaSortOrder.QUARANTINED_BY.value,
- MediaSortOrder.SAFE_FROM_QUARANTINE.value,
- ),
+ allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
@@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
request,
"order_by",
default=MediaSortOrder.CREATED_TS.value,
- allowed_values=(
- MediaSortOrder.MEDIA_ID.value,
- MediaSortOrder.UPLOAD_NAME.value,
- MediaSortOrder.CREATED_TS.value,
- MediaSortOrder.LAST_ACCESS_TS.value,
- MediaSortOrder.MEDIA_LENGTH.value,
- MediaSortOrder.MEDIA_TYPE.value,
- MediaSortOrder.QUARANTINED_BY.value,
- MediaSortOrder.SAFE_FROM_QUARANTINE.value,
- ),
+ allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py
index 891b98c088..04948b6408 100644
--- a/synapse/rest/admin/registration_tokens.py
+++ b/synapse/rest/admin/registration_tokens.py
@@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/new$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 829e86675a..6030373ebc 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
If 'purge' is true, it will remove all traces of a room from the database.
"""
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet):
class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
"""Get the status of the delete room background task."""
- PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
+ PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
@@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
# Extract query parameters
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
- order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
- if order_by not in (
- RoomSortOrder.ALPHABETICAL.value,
- RoomSortOrder.SIZE.value,
- RoomSortOrder.NAME.value,
- RoomSortOrder.CANONICAL_ALIAS.value,
- RoomSortOrder.JOINED_MEMBERS.value,
- RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
- RoomSortOrder.VERSION.value,
- RoomSortOrder.CREATOR.value,
- RoomSortOrder.ENCRYPTION.value,
- RoomSortOrder.FEDERATABLE.value,
- RoomSortOrder.PUBLIC.value,
- RoomSortOrder.JOIN_RULES.value,
- RoomSortOrder.GUEST_ACCESS.value,
- RoomSortOrder.HISTORY_VISIBILITY.value,
- RoomSortOrder.STATE_EVENTS.value,
- ):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown value for order_by: %s" % (order_by,),
- errcode=Codes.INVALID_PARAM,
- )
+ order_by = parse_string(
+ request,
+ "order_by",
+ default=RoomSortOrder.NAME.value,
+ allowed_values=[sort_order.value for sort_order in RoomSortOrder],
+ )
search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "":
@@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
TODO: Add on_POST to allow room creation without joining the room
"""
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
@@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
Get members list of a room.
"""
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
Get full state within a room.
"""
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
@@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet):
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
- PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+ PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
+ self.is_mine = hs.is_mine
async def on_POST(
self, request: SynapseRequest, room_identifier: str
@@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
assert_params_in_dict(content, ["user_id"])
target_user = UserID.from_string(content["user_id"])
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"This endpoint can only be used with local users",
@@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
}
"""
- PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
"""
- PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_DELETE(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
room_id, _ = await self.resolve_room_id(room_identifier)
@@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
async def on_GET(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_requester_is_admin(self.auth, request)
room_id, _ = await self.resolve_room_id(room_identifier)
@@ -771,13 +745,19 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
results["events_before"] = await self._event_serializer.serialize_events(
- results["events_before"], time_now
+ results["events_before"],
+ time_now,
+ bundle_aggregations=True,
)
results["event"] = await self._event_serializer.serialize_event(
- results["event"], time_now
+ results["event"],
+ time_now,
+ bundle_aggregations=True,
)
results["events_after"] = await self._event_serializer.serialize_events(
- results["events_after"], time_now
+ results["events_after"],
+ time_now,
+ bundle_aggregations=True,
)
results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
@@ -793,7 +773,7 @@ class BlockRoomRestServlet(RestServlet):
On GET: Get blocking status of room and user who has blocked this room.
"""
- PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$")
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index b295fb078b..15da9cd881 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.server_notices_manager = hs.get_server_notices_manager()
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
+ self.is_mine = hs.is_mine
def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
@@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet):
)
target_user = UserID.from_string(body["user_id"])
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
)
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index ca41fd45f2..7a6546372e 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
PATTERNS = admin_patterns("/statistics/users/media$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet):
await assert_requester_is_admin(self.auth, request)
order_by = parse_string(
- request, "order_by", default=UserSortOrder.USER_ID.value
+ request,
+ "order_by",
+ default=UserSortOrder.USER_ID.value,
+ allowed_values=(
+ UserSortOrder.MEDIA_LENGTH.value,
+ UserSortOrder.MEDIA_COUNT.value,
+ UserSortOrder.USER_ID.value,
+ UserSortOrder.DISPLAYNAME.value,
+ ),
)
- if order_by not in (
- UserSortOrder.MEDIA_LENGTH.value,
- UserSortOrder.MEDIA_COUNT.value,
- UserSortOrder.USER_ID.value,
- UserSortOrder.DISPLAYNAME.value,
- ):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown value for order_by: %s" % (order_by,),
- errcode=Codes.INVALID_PARAM,
- )
start = parse_integer(request, "from", default=0)
if start < 0:
diff --git a/synapse/rest/admin/username_available.py b/synapse/rest/admin/username_available.py
index 2bf1472967..5353dc3682 100644
--- a/synapse/rest/admin/username_available.py
+++ b/synapse/rest/admin/username_available.py
@@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
}
"""
- PATTERNS = admin_patterns("/username_available")
+ PATTERNS = admin_patterns("/username_available$")
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2a60b602b1..78e795c347 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
@@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet):
class UserRestServletV2(RestServlet):
- PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2")
"""Get request to list user details.
This needs user to have administrator access in Synapse.
@@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
- PATTERNS = admin_patterns("/register")
+ PATTERNS = admin_patterns("/register$")
NONCE_TIMEOUT = 60
def __init__(self, hs: "HomeServer"):
@@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet):
]
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet):
if target_user != auth_user:
await assert_user_is_admin(self.auth, auth_user)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
ret = await self.admin_handler.get_whois(target_user)
@@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
@@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
@@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, target_user_id: str
@@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet):
# if not is_admin and target_user != auth_user:
# raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
term = parse_string(request, "term", required=True)
@@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet):
target_user = UserID.from_string(user_id)
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver",
@@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet):
assert_params_in_dict(body, ["admin"])
- if not self.hs.is_mine(target_user):
+ if not self.is_mine(target_user):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Only local users can be admins of this homeserver",
@@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet):
Get room list of an user.
"""
- PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$")
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
@@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.is_mine_id = hs.is_mine_id
async def on_POST(
self, request: SynapseRequest, user_id: str
@@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet):
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
)
@@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet):
{}
"""
- PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine_id = hs.is_mine_id
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
)
@@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
)
@@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
}
"""
- PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
def __init__(self, hs: "HomeServer"):
- self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self.is_mine_id = hs.is_mine_id
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
if not await self.store.get_user_by_id(user_id):
@@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
)
@@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
- if not self.hs.is_mine_id(user_id):
+ if not self.is_mine_id(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
)
@@ -1124,3 +1121,33 @@ class RateLimitRestServlet(RestServlet):
await self.store.delete_ratelimit_for_user(user_id)
return HTTPStatus.OK, {}
+
+
+class AccountDataRestServlet(RestServlet):
+ """Retrieve the given user's account data"""
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/accountdata")
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastore()
+ self._is_mine_id = hs.is_mine_id
+
+ async def on_GET(
+ self, request: SynapseRequest, user_id: str
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ if not self._is_mine_id(user_id):
+ raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
+
+ if not await self._store.get_user_by_id(user_id):
+ raise NotFoundError("User not found")
+
+ global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+ return HTTPStatus.OK, {
+ "account_data": {
+ "global": global_data,
+ "rooms": by_room_data,
+ },
+ }
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 8566dc5cb5..ad6fd6492b 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -17,6 +17,7 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api import errors
+from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -24,10 +25,9 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict
-from ._base import client_patterns, interactive_auth_handler
-
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet):
device = await self.device_handler.get_device(
requester.user.to_string(), device_id
)
+ if device is None:
+ raise NotFoundError("No device found")
return 200, device
@interactive_auth_handler
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index d1d8a984c6..acd0c9e135 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import ReceiptTypes
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -54,10 +55,10 @@ class NotificationsServlet(RestServlet):
)
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
- user_id, "m.read"
+ user_id, ReceiptTypes.READ
)
- notif_event_ids = [pa["event_id"] for pa in push_actions]
+ notif_event_ids = [pa.event_id for pa in push_actions]
notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = []
@@ -66,30 +67,30 @@ class NotificationsServlet(RestServlet):
for pa in push_actions:
returned_pa = {
- "room_id": pa["room_id"],
- "profile_tag": pa["profile_tag"],
- "actions": pa["actions"],
- "ts": pa["received_ts"],
+ "room_id": pa.room_id,
+ "profile_tag": pa.profile_tag,
+ "actions": pa.actions,
+ "ts": pa.received_ts,
"event": (
await self._event_serializer.serialize_event(
- notif_events[pa["event_id"]],
+ notif_events[pa.event_id],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
)
),
}
- if pa["room_id"] not in receipts_by_room:
+ if pa.room_id not in receipts_by_room:
returned_pa["read"] = False
else:
- receipt = receipts_by_room[pa["room_id"]]
+ receipt = receipts_by_room[pa.room_id]
returned_pa["read"] = (
receipt["topological_ordering"],
receipt["stream_ordering"],
- ) >= (pa["topological_ordering"], pa["stream_ordering"])
+ ) >= (pa.topological_ordering, pa.stream_ordering)
returned_push_actions.append(returned_pa)
- next_token = str(pa["stream_ordering"])
+ next_token = str(pa.stream_ordering)
return 200, {"notifications": returned_push_actions, "next_token": next_token}
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 43c04fac6f..f51be511d1 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
- read_event_id = body.get("m.read", None)
+ read_event_id = body.get(ReceiptTypes.READ, None)
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
if not isinstance(hidden, bool):
@@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet):
if read_event_id:
await self.receipts_handler.received_client_receipt(
room_id,
- "m.read",
+ ReceiptTypes.READ,
user_id=requester.user.to_string(),
event_id=read_event_id,
hidden=hidden,
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 2b25b9aad6..b24ad2d1be 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -16,7 +16,7 @@ import logging
import re
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.api.errors import Codes, SynapseError
from synapse.http import get_request_user_agent
from synapse.http.server import HttpServer
@@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- if receipt_type != "m.read":
+ if receipt_type != ReceiptTypes.READ:
raise SynapseError(400, "Receipt type must be 'm.read'")
# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index fc4e6921c5..5815650ee6 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
+ room_id=room_id,
relation_type=relation_type,
event_type=event_type,
limit=limit,
@@ -231,7 +232,9 @@ class RelationPaginationServlet(RestServlet):
)
# The relations returned for the requested event do include their
# bundled aggregations.
- serialized_events = await self._event_serializer.serialize_events(events, now)
+ serialized_events = await self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=True
+ )
return_value = pagination_chunk.to_dict()
return_value["chunk"] = serialized_events
@@ -317,6 +320,7 @@ class RelationAggregationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
+ room_id=room_id,
event_type=event_type,
limit=limit,
from_token=from_token,
@@ -383,7 +387,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
- await self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -402,6 +408,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
result = await self.store.get_relations_for_event(
event_id=parent_id,
+ room_id=room_id,
relation_type=relation_type,
event_type=event_type,
aggregation_key=key,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index f48e2e6ca2..40330749e5 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
state_key: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if txn_id:
set_tag("txn_id", txn_id)
@@ -662,7 +662,9 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event_dict = await self._event_serializer.serialize_event(event, time_now)
+ event_dict = await self._event_serializer.serialize_event(
+ event, time_now, bundle_aggregations=True
+ )
return 200, event_dict
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -707,13 +709,13 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
results["events_before"] = await self._event_serializer.serialize_events(
- results["events_before"], time_now
+ results["events_before"], time_now, bundle_aggregations=True
)
results["event"] = await self._event_serializer.serialize_event(
- results["event"], time_now
+ results["event"], time_now, bundle_aggregations=True
)
results["events_after"] = await self._event_serializer.serialize_events(
- results["events_after"], time_now
+ results["events_after"], time_now, bundle_aggregations=True
)
results["state"] = await self._event_serializer.serialize_events(
results["state"], time_now
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index e556ff93e6..e99a943d0d 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@ from synapse.handlers.sync import (
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
+from synapse.logging.opentracing import trace
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -222,6 +223,7 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
+ @trace(opname="sync.encode_response")
async def encode_response(
self,
time_now: int,
@@ -293,6 +295,9 @@ class SyncRestServlet(RestServlet):
response[
"org.matrix.msc2732.device_unused_fallback_key_types"
] = sync_result.device_unused_fallback_key_types
+ response[
+ "device_unused_fallback_key_types"
+ ] = sync_result.device_unused_fallback_key_types
if joined:
response["rooms"][Membership.JOIN] = joined
@@ -329,6 +334,7 @@ class SyncRestServlet(RestServlet):
]
}
+ @trace(opname="sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
@@ -365,6 +371,7 @@ class SyncRestServlet(RestServlet):
return joined
+ @trace(opname="sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
@@ -403,6 +410,7 @@ class SyncRestServlet(RestServlet):
return invited
+ @trace(opname="sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
@@ -457,6 +465,7 @@ class SyncRestServlet(RestServlet):
return knocked
+ @trace(opname="sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],
@@ -528,6 +537,8 @@ class SyncRestServlet(RestServlet):
# overhead for initialsyncs. We need to figure out a way that the
# bundling can be done *before* the events are stored in the
# SyncResponseCache so that this part can be synchronous.
+ #
+ # Ensure to re-enable the test at tests/rest/client/test_relations.py::RelationsTestCase.test_bundled_aggregations.
bundle_aggregations=False,
token_id=token_id,
event_format=event_formatter,
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 8d888f4565..2290c57c12 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -93,6 +93,10 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving hidden read receipts as per MSC2285
"org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+ # Adds support for importing historical messages as per MSC2716
+ "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
+ # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
+ "org.matrix.msc3030": self.config.experimental.msc3030_enabled,
},
},
)
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 12b3ae120c..b9bfbea21b 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
@@ -99,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
- def render_GET(self, request: Request) -> int:
+ def render_GET(self, request: Request) -> Optional[int]:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 244ba261bb..71b9a34b14 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -739,14 +739,21 @@ class MediaRepository:
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails: Dict[Tuple[int, int, str], str] = {}
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "crop":
- thumbnails.setdefault((r_width, r_height, r_type), r_method)
- elif r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ for requirement in requirements:
+ if requirement.method == "crop":
+ thumbnails.setdefault(
+ (requirement.width, requirement.height, requirement.media_type),
+ requirement.method,
+ )
+ elif requirement.method == "scale":
+ t_width, t_height = thumbnailer.aspect(
+ requirement.width, requirement.height
+ )
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
- thumbnails[(t_width, t_height, r_type)] = r_method
+ thumbnails[
+ (t_width, t_height, requirement.media_type)
+ ] = requirement.method
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.items():
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2a59552c20..cce1527ed9 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, List, Optional
import attr
+from synapse.rest.media.v1.preview_html import parse_html_description
from synapse.types import JsonDict
from synapse.util import json_decoder
@@ -245,8 +246,6 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
if video_urls:
open_graph_response["og:video"] = video_urls[0]
- from synapse.rest.media.v1.preview_url_resource import _calc_description
-
- description = _calc_description(tree)
+ description = parse_html_description(tree)
if description:
open_graph_response["og:description"] = description
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py
new file mode 100644
index 0000000000..30b067dd42
--- /dev/null
+++ b/synapse/rest/media/v1/preview_html.py
@@ -0,0 +1,397 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import codecs
+import itertools
+import logging
+import re
+from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Union
+from urllib import parse as urlparse
+
+if TYPE_CHECKING:
+ from lxml import etree
+
+logger = logging.getLogger(__name__)
+
+_charset_match = re.compile(
+ br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
+)
+_xml_encoding_match = re.compile(
+ br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
+)
+_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+
+
+def _normalise_encoding(encoding: str) -> Optional[str]:
+ """Use the Python codec's name as the normalised entry."""
+ try:
+ return codecs.lookup(encoding).name
+ except LookupError:
+ return None
+
+
+def _get_html_media_encodings(
+ body: bytes, content_type: Optional[str]
+) -> Iterable[str]:
+ """
+ Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
+
+ The precedence used for finding a character encoding is:
+
+ 1. <meta> tag with a charset declared.
+ 2. The XML document's character encoding attribute.
+ 3. The Content-Type header.
+ 4. Fallback to utf-8.
+ 5. Fallback to windows-1252.
+
+ This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
+
+ Args:
+ body: The HTML document, as bytes.
+ content_type: The Content-Type header.
+
+ Returns:
+ The character encoding of the body, as a string.
+ """
+ # There's no point in returning an encoding more than once.
+ attempted_encodings: Set[str] = set()
+
+ # Limit searches to the first 1kb, since it ought to be at the top.
+ body_start = body[:1024]
+
+ # Check if it has an encoding set in a meta tag.
+ match = _charset_match.search(body_start)
+ if match:
+ encoding = _normalise_encoding(match.group(1).decode("ascii"))
+ if encoding:
+ attempted_encodings.add(encoding)
+ yield encoding
+
+ # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+ # Check if it has an XML document with an encoding.
+ match = _xml_encoding_match.match(body_start)
+ if match:
+ encoding = _normalise_encoding(match.group(1).decode("ascii"))
+ if encoding and encoding not in attempted_encodings:
+ attempted_encodings.add(encoding)
+ yield encoding
+
+ # Check the HTTP Content-Type header for a character set.
+ if content_type:
+ content_match = _content_type_match.match(content_type)
+ if content_match:
+ encoding = _normalise_encoding(content_match.group(1))
+ if encoding and encoding not in attempted_encodings:
+ attempted_encodings.add(encoding)
+ yield encoding
+
+ # Finally, fallback to UTF-8, then windows-1252.
+ for fallback in ("utf-8", "cp1252"):
+ if fallback not in attempted_encodings:
+ yield fallback
+
+
+def decode_body(
+ body: bytes, uri: str, content_type: Optional[str] = None
+) -> Optional["etree.Element"]:
+ """
+ This uses lxml to parse the HTML document.
+
+ Args:
+ body: The HTML document, as bytes.
+ uri: The URI used to download the body.
+ content_type: The Content-Type header.
+
+ Returns:
+ The parsed HTML body, or None if an error occurred during processed.
+ """
+ # If there's no body, nothing useful is going to be found.
+ if not body:
+ return None
+
+ # The idea here is that multiple encodings are tried until one works.
+ # Unfortunately the result is never used and then LXML will decode the string
+ # again with the found encoding.
+ for encoding in _get_html_media_encodings(body, content_type):
+ try:
+ body.decode(encoding)
+ except Exception:
+ pass
+ else:
+ break
+ else:
+ logger.warning("Unable to decode HTML body for %s", uri)
+ return None
+
+ from lxml import etree
+
+ # Create an HTML parser.
+ parser = etree.HTMLParser(recover=True, encoding=encoding)
+
+ # Attempt to parse the body. Returns None if the body was successfully
+ # parsed, but no tree was found.
+ return etree.fromstring(body, parser)
+
+
+def parse_html_to_open_graph(
+ tree: "etree.Element", media_uri: str
+) -> Dict[str, Optional[str]]:
+ """
+ Parse the HTML document into an Open Graph response.
+
+ This uses lxml to search the HTML document for Open Graph data (or
+ synthesizes it from the document).
+
+ Args:
+ tree: The parsed HTML document.
+ media_url: The URI used to download the body.
+
+ Returns:
+ The Open Graph response as a dictionary.
+ """
+
+ # if we see any image URLs in the OG response, then spider them
+ # (although the client could choose to do this by asking for previews of those
+ # URLs to avoid DoSing the server)
+
+ # "og:type" : "video",
+ # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
+ # "og:site_name" : "YouTube",
+ # "og:video:type" : "application/x-shockwave-flash",
+ # "og:description" : "Fun stuff happening here",
+ # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
+ # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
+ # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
+ # "og:video:width" : "1280"
+ # "og:video:height" : "720",
+ # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
+
+ og: Dict[str, Optional[str]] = {}
+ for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
+ if "content" in tag.attrib:
+ # if we've got more than 50 tags, someone is taking the piss
+ if len(og) >= 50:
+ logger.warning("Skipping OG for page with too many 'og:' tags")
+ return {}
+ og[tag.attrib["property"]] = tag.attrib["content"]
+
+ # TODO: grab article: meta tags too, e.g.:
+
+ # "article:publisher" : "https://www.facebook.com/thethudonline" />
+ # "article:author" content="https://www.facebook.com/thethudonline" />
+ # "article:tag" content="baby" />
+ # "article:section" content="Breaking News" />
+ # "article:published_time" content="2016-03-31T19:58:24+00:00" />
+ # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
+
+ if "og:title" not in og:
+ # do some basic spidering of the HTML
+ title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
+ if title and title[0].text is not None:
+ og["og:title"] = title[0].text.strip()
+ else:
+ og["og:title"] = None
+
+ if "og:image" not in og:
+ # TODO: extract a favicon failing all else
+ meta_image = tree.xpath(
+ "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
+ )
+ if meta_image:
+ og["og:image"] = rebase_url(meta_image[0], media_uri)
+ else:
+ # TODO: consider inlined CSS styles as well as width & height attribs
+ images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
+ images = sorted(
+ images,
+ key=lambda i: (
+ -1 * float(i.attrib["width"]) * float(i.attrib["height"])
+ ),
+ )
+ if not images:
+ images = tree.xpath("//img[@src]")
+ if images:
+ og["og:image"] = images[0].attrib["src"]
+
+ if "og:description" not in og:
+ meta_description = tree.xpath(
+ "//*/meta"
+ "[translate(@name, 'DESCRIPTION', 'description')='description']"
+ "/@content"
+ )
+ if meta_description:
+ og["og:description"] = meta_description[0]
+ else:
+ og["og:description"] = parse_html_description(tree)
+ elif og["og:description"]:
+ # This must be a non-empty string at this point.
+ assert isinstance(og["og:description"], str)
+ og["og:description"] = summarize_paragraphs([og["og:description"]])
+
+ # TODO: delete the url downloads to stop diskfilling,
+ # as we only ever cared about its OG
+ return og
+
+
+def parse_html_description(tree: "etree.Element") -> Optional[str]:
+ """
+ Calculate a text description based on an HTML document.
+
+ Grabs any text nodes which are inside the <body/> tag, unless they are within
+ an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
+ if they are within a <script/> or <style/> tag.
+
+ This is a very very very coarse approximation to a plain text render of the page.
+
+ Args:
+ tree: The parsed HTML document.
+
+ Returns:
+ The plain text description, or None if one cannot be generated.
+ """
+ # We don't just use XPATH here as that is slow on some machines.
+
+ from lxml import etree
+
+ TAGS_TO_REMOVE = (
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment,
+ )
+
+ # Split all the text nodes into paragraphs (by splitting on new
+ # lines)
+ text_nodes = (
+ re.sub(r"\s+", "\n", el).strip()
+ for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
+ )
+ return summarize_paragraphs(text_nodes)
+
+
+def _iterate_over_text(
+ tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
+ """Iterate over the tree returning text nodes in a depth first fashion,
+ skipping text nodes inside certain tags.
+ """
+ # This is basically a stack that we extend using itertools.chain.
+ # This will either consist of an element to iterate over *or* a string
+ # to be returned.
+ elements = iter([tree])
+ while True:
+ el = next(elements, None)
+ if el is None:
+ return
+
+ if isinstance(el, str):
+ yield el
+ elif el.tag not in tags_to_ignore:
+ # el.text is the text before the first child, so we can immediately
+ # return it if the text exists.
+ if el.text:
+ yield el.text
+
+ # We add to the stack all the elements children, interspersed with
+ # each child's tail text (if it exists). The tail text of a node
+ # is text that comes *after* the node, so we always include it even
+ # if we ignore the child node.
+ elements = itertools.chain(
+ itertools.chain.from_iterable( # Basically a flatmap
+ [child, child.tail] if child.tail else [child]
+ for child in el.iterchildren()
+ ),
+ elements,
+ )
+
+
+def rebase_url(url: str, base: str) -> str:
+ base_parts = list(urlparse.urlparse(base))
+ url_parts = list(urlparse.urlparse(url))
+ if not url_parts[0]: # fix up schema
+ url_parts[0] = base_parts[0] or "http"
+ if not url_parts[1]: # fix up hostname
+ url_parts[1] = base_parts[1]
+ if not url_parts[2].startswith("/"):
+ url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+ return urlparse.urlunparse(url_parts)
+
+
+def summarize_paragraphs(
+ text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
+ """
+ Try to get a summary respecting first paragraph and then word boundaries.
+
+ Args:
+ text_nodes: The paragraphs to summarize.
+ min_size: The minimum number of words to include.
+ max_size: The maximum number of words to include.
+
+ Returns:
+ A summary of the text nodes, or None if that was not possible.
+ """
+
+ # TODO: Respect sentences?
+
+ description = ""
+
+ # Keep adding paragraphs until we get to the MIN_SIZE.
+ for text_node in text_nodes:
+ if len(description) < min_size:
+ text_node = re.sub(r"[\t \r\n]+", " ", text_node)
+ description += text_node + "\n\n"
+ else:
+ break
+
+ description = description.strip()
+ description = re.sub(r"[\t ]+", " ", description)
+ description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
+
+ # If the concatenation of paragraphs to get above MIN_SIZE
+ # took us over MAX_SIZE, then we need to truncate mid paragraph
+ if len(description) > max_size:
+ new_desc = ""
+
+ # This splits the paragraph into words, but keeping the
+ # (preceding) whitespace intact so we can easily concat
+ # words back together.
+ for match in re.finditer(r"\s*\S+", description):
+ word = match.group()
+
+ # Keep adding words while the total length is less than
+ # MAX_SIZE.
+ if len(word) + len(new_desc) < max_size:
+ new_desc += word
+ else:
+ # At this point the next word *will* take us over
+ # MAX_SIZE, but we also want to ensure that its not
+ # a huge word. If it is add it anyway and we'll
+ # truncate later.
+ if len(new_desc) < min_size:
+ new_desc += word
+ break
+
+ # Double check that we're not over the limit
+ if len(new_desc) > max_size:
+ new_desc = new_desc[:max_size]
+
+ # We always add an ellipsis because at the very least
+ # we chopped mid paragraph.
+ description = new_desc.strip() + "…"
+ return description if description else None
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 054f3c296d..a3829d943b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,18 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import codecs
import datetime
import errno
import fnmatch
-import itertools
import logging
import os
import re
import shutil
import sys
import traceback
-from typing import TYPE_CHECKING, Dict, Generator, Iterable, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from urllib import parse as urlparse
import attr
@@ -45,6 +43,11 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.rest.media.v1.preview_html import (
+ decode_body,
+ parse_html_to_open_graph,
+ rebase_url,
+)
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
@@ -54,21 +57,11 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
if TYPE_CHECKING:
- from lxml import etree
-
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-_charset_match = re.compile(
- br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
-)
-_xml_encoding_match = re.compile(
- br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
-)
-_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
-
OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
@@ -311,7 +304,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# If there was no oEmbed URL (or oEmbed parsing failed), attempt
# to generate the Open Graph information from the HTML.
if not oembed_url or not og:
- og = _calc_og(tree, media_info.uri)
+ og = parse_html_to_open_graph(tree, media_info.uri)
await self._precache_image_url(user, media_info, og)
else:
@@ -468,7 +461,7 @@ class PreviewUrlResource(DirectServeJsonResource):
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
image_info = await self._download_url(
- _rebase_url(og["og:image"], media_info.uri), user
+ rebase_url(og["og:image"], media_info.uri), user
)
if _is_media(image_info.media_type):
@@ -632,301 +625,6 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
-def _normalise_encoding(encoding: str) -> Optional[str]:
- """Use the Python codec's name as the normalised entry."""
- try:
- return codecs.lookup(encoding).name
- except LookupError:
- return None
-
-
-def get_html_media_encodings(body: bytes, content_type: Optional[str]) -> Iterable[str]:
- """
- Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
-
- The precedence used for finding a character encoding is:
-
- 1. <meta> tag with a charset declared.
- 2. The XML document's character encoding attribute.
- 3. The Content-Type header.
- 4. Fallback to utf-8.
- 5. Fallback to windows-1252.
-
- This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
-
- Args:
- body: The HTML document, as bytes.
- content_type: The Content-Type header.
-
- Returns:
- The character encoding of the body, as a string.
- """
- # There's no point in returning an encoding more than once.
- attempted_encodings: Set[str] = set()
-
- # Limit searches to the first 1kb, since it ought to be at the top.
- body_start = body[:1024]
-
- # Check if it has an encoding set in a meta tag.
- match = _charset_match.search(body_start)
- if match:
- encoding = _normalise_encoding(match.group(1).decode("ascii"))
- if encoding:
- attempted_encodings.add(encoding)
- yield encoding
-
- # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
-
- # Check if it has an XML document with an encoding.
- match = _xml_encoding_match.match(body_start)
- if match:
- encoding = _normalise_encoding(match.group(1).decode("ascii"))
- if encoding and encoding not in attempted_encodings:
- attempted_encodings.add(encoding)
- yield encoding
-
- # Check the HTTP Content-Type header for a character set.
- if content_type:
- content_match = _content_type_match.match(content_type)
- if content_match:
- encoding = _normalise_encoding(content_match.group(1))
- if encoding and encoding not in attempted_encodings:
- attempted_encodings.add(encoding)
- yield encoding
-
- # Finally, fallback to UTF-8, then windows-1252.
- for fallback in ("utf-8", "cp1252"):
- if fallback not in attempted_encodings:
- yield fallback
-
-
-def decode_body(
- body: bytes, uri: str, content_type: Optional[str] = None
-) -> Optional["etree.Element"]:
- """
- This uses lxml to parse the HTML document.
-
- Args:
- body: The HTML document, as bytes.
- uri: The URI used to download the body.
- content_type: The Content-Type header.
-
- Returns:
- The parsed HTML body, or None if an error occurred during processed.
- """
- # If there's no body, nothing useful is going to be found.
- if not body:
- return None
-
- # The idea here is that multiple encodings are tried until one works.
- # Unfortunately the result is never used and then LXML will decode the string
- # again with the found encoding.
- for encoding in get_html_media_encodings(body, content_type):
- try:
- body.decode(encoding)
- except Exception:
- pass
- else:
- break
- else:
- logger.warning("Unable to decode HTML body for %s", uri)
- return None
-
- from lxml import etree
-
- # Create an HTML parser.
- parser = etree.HTMLParser(recover=True, encoding=encoding)
-
- # Attempt to parse the body. Returns None if the body was successfully
- # parsed, but no tree was found.
- return etree.fromstring(body, parser)
-
-
-def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
- """
- Calculate metadata for an HTML document.
-
- This uses lxml to search the HTML document for Open Graph data.
-
- Args:
- tree: The parsed HTML document.
- media_url: The URI used to download the body.
-
- Returns:
- The Open Graph response as a dictionary.
- """
-
- # if we see any image URLs in the OG response, then spider them
- # (although the client could choose to do this by asking for previews of those
- # URLs to avoid DoSing the server)
-
- # "og:type" : "video",
- # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
- # "og:site_name" : "YouTube",
- # "og:video:type" : "application/x-shockwave-flash",
- # "og:description" : "Fun stuff happening here",
- # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
- # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
- # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
- # "og:video:width" : "1280"
- # "og:video:height" : "720",
- # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
-
- og: Dict[str, Optional[str]] = {}
- for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
- if "content" in tag.attrib:
- # if we've got more than 50 tags, someone is taking the piss
- if len(og) >= 50:
- logger.warning("Skipping OG for page with too many 'og:' tags")
- return {}
- og[tag.attrib["property"]] = tag.attrib["content"]
-
- # TODO: grab article: meta tags too, e.g.:
-
- # "article:publisher" : "https://www.facebook.com/thethudonline" />
- # "article:author" content="https://www.facebook.com/thethudonline" />
- # "article:tag" content="baby" />
- # "article:section" content="Breaking News" />
- # "article:published_time" content="2016-03-31T19:58:24+00:00" />
- # "article:modified_time" content="2016-04-01T18:31:53+00:00" />
-
- if "og:title" not in og:
- # do some basic spidering of the HTML
- title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
- if title and title[0].text is not None:
- og["og:title"] = title[0].text.strip()
- else:
- og["og:title"] = None
-
- if "og:image" not in og:
- # TODO: extract a favicon failing all else
- meta_image = tree.xpath(
- "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
- )
- if meta_image:
- og["og:image"] = _rebase_url(meta_image[0], media_uri)
- else:
- # TODO: consider inlined CSS styles as well as width & height attribs
- images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
- images = sorted(
- images,
- key=lambda i: (
- -1 * float(i.attrib["width"]) * float(i.attrib["height"])
- ),
- )
- if not images:
- images = tree.xpath("//img[@src]")
- if images:
- og["og:image"] = images[0].attrib["src"]
-
- if "og:description" not in og:
- meta_description = tree.xpath(
- "//*/meta"
- "[translate(@name, 'DESCRIPTION', 'description')='description']"
- "/@content"
- )
- if meta_description:
- og["og:description"] = meta_description[0]
- else:
- og["og:description"] = _calc_description(tree)
- elif og["og:description"]:
- # This must be a non-empty string at this point.
- assert isinstance(og["og:description"], str)
- og["og:description"] = summarize_paragraphs([og["og:description"]])
-
- # TODO: delete the url downloads to stop diskfilling,
- # as we only ever cared about its OG
- return og
-
-
-def _calc_description(tree: "etree.Element") -> Optional[str]:
- """
- Calculate a text description based on an HTML document.
-
- Grabs any text nodes which are inside the <body/> tag, unless they are within
- an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
- if they are within a <script/> or <style/> tag.
-
- This is a very very very coarse approximation to a plain text render of the page.
-
- Args:
- tree: The parsed HTML document.
-
- Returns:
- The plain text description, or None if one cannot be generated.
- """
- # We don't just use XPATH here as that is slow on some machines.
-
- from lxml import etree
-
- TAGS_TO_REMOVE = (
- "header",
- "nav",
- "aside",
- "footer",
- "script",
- "noscript",
- "style",
- etree.Comment,
- )
-
- # Split all the text nodes into paragraphs (by splitting on new
- # lines)
- text_nodes = (
- re.sub(r"\s+", "\n", el).strip()
- for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
- )
- return summarize_paragraphs(text_nodes)
-
-
-def _iterate_over_text(
- tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
-) -> Generator[str, None, None]:
- """Iterate over the tree returning text nodes in a depth first fashion,
- skipping text nodes inside certain tags.
- """
- # This is basically a stack that we extend using itertools.chain.
- # This will either consist of an element to iterate over *or* a string
- # to be returned.
- elements = iter([tree])
- while True:
- el = next(elements, None)
- if el is None:
- return
-
- if isinstance(el, str):
- yield el
- elif el.tag not in tags_to_ignore:
- # el.text is the text before the first child, so we can immediately
- # return it if the text exists.
- if el.text:
- yield el.text
-
- # We add to the stack all the elements children, interspersed with
- # each child's tail text (if it exists). The tail text of a node
- # is text that comes *after* the node, so we always include it even
- # if we ignore the child node.
- elements = itertools.chain(
- itertools.chain.from_iterable( # Basically a flatmap
- [child, child.tail] if child.tail else [child]
- for child in el.iterchildren()
- ),
- elements,
- )
-
-
-def _rebase_url(url: str, base: str) -> str:
- base_parts = list(urlparse.urlparse(base))
- url_parts = list(urlparse.urlparse(url))
- if not url_parts[0]: # fix up schema
- url_parts[0] = base_parts[0] or "http"
- if not url_parts[1]: # fix up hostname
- url_parts[1] = base_parts[1]
- if not url_parts[2].startswith("/"):
- url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
- return urlparse.urlunparse(url_parts)
-
-
def _is_media(content_type: str) -> bool:
return content_type.lower().startswith("image/")
@@ -940,68 +638,3 @@ def _is_html(content_type: str) -> bool:
def _is_json(content_type: str) -> bool:
return content_type.lower().startswith("application/json")
-
-
-def summarize_paragraphs(
- text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
-) -> Optional[str]:
- """
- Try to get a summary respecting first paragraph and then word boundaries.
-
- Args:
- text_nodes: The paragraphs to summarize.
- min_size: The minimum number of words to include.
- max_size: The maximum number of words to include.
-
- Returns:
- A summary of the text nodes, or None if that was not possible.
- """
-
- # TODO: Respect sentences?
-
- description = ""
-
- # Keep adding paragraphs until we get to the MIN_SIZE.
- for text_node in text_nodes:
- if len(description) < min_size:
- text_node = re.sub(r"[\t \r\n]+", " ", text_node)
- description += text_node + "\n\n"
- else:
- break
-
- description = description.strip()
- description = re.sub(r"[\t ]+", " ", description)
- description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
-
- # If the concatenation of paragraphs to get above MIN_SIZE
- # took us over MAX_SIZE, then we need to truncate mid paragraph
- if len(description) > max_size:
- new_desc = ""
-
- # This splits the paragraph into words, but keeping the
- # (preceding) whitespace intact so we can easily concat
- # words back together.
- for match in re.finditer(r"\s*\S+", description):
- word = match.group()
-
- # Keep adding words while the total length is less than
- # MAX_SIZE.
- if len(word) + len(new_desc) < max_size:
- new_desc += word
- else:
- # At this point the next word *will* take us over
- # MAX_SIZE, but we also want to ensure that its not
- # a huge word. If it is add it anyway and we'll
- # truncate later.
- if len(new_desc) < min_size:
- new_desc += word
- break
-
- # Double check that we're not over the limit
- if len(new_desc) > max_size:
- new_desc = new_desc[:max_size]
-
- # We always add an ellipsis because at the very least
- # we chopped mid paragraph.
- description = new_desc.strip() + "…"
- return description if description else None
|