diff --git a/changelog.d/9150.feature b/changelog.d/9150.feature
new file mode 100644
index 0000000000..48a8148dee
--- /dev/null
+++ b/changelog.d/9150.feature
@@ -0,0 +1 @@
+New API /_synapse/admin/rooms/{roomId}/context/{eventId}.
diff --git a/changelog.d/9299.misc b/changelog.d/9299.misc
new file mode 100644
index 0000000000..c883a677ed
--- /dev/null
+++ b/changelog.d/9299.misc
@@ -0,0 +1 @@
+Update the `Cursor` type hints to better match PEP 249.
diff --git a/changelog.d/9321.bugfix b/changelog.d/9321.bugfix
new file mode 100644
index 0000000000..52eed80969
--- /dev/null
+++ b/changelog.d/9321.bugfix
@@ -0,0 +1 @@
+Assert a maximum length for the `client_secret` parameter for spec compliance.
diff --git a/changelog.d/9333.bugfix b/changelog.d/9333.bugfix
new file mode 100644
index 0000000000..c34ba378c5
--- /dev/null
+++ b/changelog.d/9333.bugfix
@@ -0,0 +1 @@
+Fix additional errors when previewing URLs: "AttributeError 'NoneType' object has no attribute 'xpath'" and "ValueError: Unicode strings with encoding declaration are not supported. Please use bytes input or XML fragments without declaration.".
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 3832b36407..bc737b30f5 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -10,6 +10,7 @@
* [Undoing room shutdowns](#undoing-room-shutdowns)
- [Make Room Admin API](#make-room-admin-api)
- [Forward Extremities Admin API](#forward-extremities-admin-api)
+- [Event Context API](#event-context-api)
# List Room API
@@ -594,3 +595,121 @@ that were deleted.
"deleted": 1
}
```
+
+# Event Context API
+
+This API lets a client find the context of an event. This is designed primarily to investigate abuse reports.
+
+```
+GET /_synapse/admin/v1/rooms/<room_id>/context/<event_id>
+```
+
+This API mimmicks [GET /_matrix/client/r0/rooms/{roomId}/context/{eventId}](https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-rooms-roomid-context-eventid). Please refer to the link for all details on parameters and reseponse.
+
+Example response:
+
+```json
+{
+ "end": "t29-57_2_0_2",
+ "events_after": [
+ {
+ "content": {
+ "body": "This is an example text message",
+ "msgtype": "m.text",
+ "format": "org.matrix.custom.html",
+ "formatted_body": "<b>This is an example text message</b>"
+ },
+ "type": "m.room.message",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ }
+ }
+ ],
+ "event": {
+ "content": {
+ "body": "filename.jpg",
+ "info": {
+ "h": 398,
+ "w": 394,
+ "mimetype": "image/jpeg",
+ "size": 31037
+ },
+ "url": "mxc://example.org/JWEIFJgwEIhweiWJE",
+ "msgtype": "m.image"
+ },
+ "type": "m.room.message",
+ "event_id": "$f3h4d129462ha:example.com",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ }
+ },
+ "events_before": [
+ {
+ "content": {
+ "body": "something-important.doc",
+ "filename": "something-important.doc",
+ "info": {
+ "mimetype": "application/msword",
+ "size": 46144
+ },
+ "msgtype": "m.file",
+ "url": "mxc://example.org/FHyPlCeYUSFFxlgbQYZmoEoe"
+ },
+ "type": "m.room.message",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ }
+ }
+ ],
+ "start": "t27-54_2_0_2",
+ "state": [
+ {
+ "content": {
+ "creator": "@example:example.org",
+ "room_version": "1",
+ "m.federate": true,
+ "predecessor": {
+ "event_id": "$something:example.org",
+ "room_id": "!oldroom:example.org"
+ }
+ },
+ "type": "m.room.create",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ },
+ "state_key": ""
+ },
+ {
+ "content": {
+ "membership": "join",
+ "avatar_url": "mxc://example.org/SEsfnsuifSDFSSEF",
+ "displayname": "Alice Margatroid"
+ },
+ "type": "m.room.member",
+ "event_id": "$143273582443PhrSn:example.org",
+ "room_id": "!636q39766251:example.com",
+ "sender": "@example:example.org",
+ "origin_server_ts": 1432735824653,
+ "unsigned": {
+ "age": 1234
+ },
+ "state_key": "@alice:example.org"
+ }
+ ]
+}
+```
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 0d042cbfac..76bf52ea23 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -18,6 +18,7 @@
import logging
from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
@@ -32,6 +33,11 @@ logger = logging.getLogger(__name__)
# TODO: Flairs
+# Note that the maximum lengths are somewhat arbitrary.
+MAX_SHORT_DESC_LEN = 1000
+MAX_LONG_DESC_LEN = 10000
+
+
class GroupsServerWorkerHandler:
def __init__(self, hs):
self.hs = hs
@@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
)
profile = {}
- for keyname in ("name", "avatar_url", "short_description", "long_description"):
+ for keyname, max_length in (
+ ("name", MAX_DISPLAYNAME_LEN),
+ ("avatar_url", MAX_AVATAR_URL_LEN),
+ ("short_description", MAX_SHORT_DESC_LEN),
+ ("long_description", MAX_LONG_DESC_LEN),
+ ):
if keyname in content:
value = content[keyname]
if not isinstance(value, str):
- raise SynapseError(400, "%r value is not a string" % (keyname,))
+ raise SynapseError(
+ 400,
+ "%r value is not a string" % (keyname,),
+ errcode=Codes.INVALID_PARAM,
+ )
+ if len(value) > max_length:
+ raise SynapseError(
+ 400,
+ "Invalid %s parameter" % (keyname,),
+ errcode=Codes.INVALID_PARAM,
+ )
profile[keyname] = value
await self.store.update_group_profile(group_id, profile)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d037742081..736070d574 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,6 +38,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
+from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
@@ -1025,41 +1026,51 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
+ self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
async def get_event_context(
self,
- user: UserID,
+ requester: Requester,
room_id: str,
event_id: str,
limit: int,
event_filter: Optional[Filter],
+ use_admin_priviledge: bool = False,
) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
- user
+ requester
room_id
event_id
limit: The maximum number of events to return in total
(excluding state).
event_filter: the filter to apply to the events returned
(excluding the target event_id)
-
+ use_admin_priviledge: if `True`, return all events, regardless
+ of whether `user` has access to them. To be used **ONLY**
+ from the admin API.
Returns:
dict, or None if the event isn't found
"""
+ user = requester.user
+ if use_admin_priviledge:
+ await assert_user_is_admin(self.auth, requester.user)
+
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
- def filter_evts(events):
- return filter_events_for_client(
+ async def filter_evts(events):
+ if use_admin_priviledge:
+ return events
+ return await filter_events_for_client(
self.storage, user.to_string(), events, is_peeking=is_peeking
)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index f5c5d164f9..8457db1e22 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -42,6 +42,7 @@ from synapse.rest.admin.rooms import (
JoinRoomAliasServlet,
ListRoomRestServlet,
MakeRoomAdminRestServlet,
+ RoomEventContextServlet,
RoomMembersRestServlet,
RoomRestServlet,
RoomStateRestServlet,
@@ -238,6 +239,7 @@ def register_servlets(hs, http_server):
MakeRoomAdminRestServlet(hs).register(http_server)
ShadowBanRestServlet(hs).register(http_server)
ForwardExtremitiesRestServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index c7f5085470..acc8f9fa0a 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
+from urllib import parse as urlparse
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.filtering import Filter
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -34,6 +36,7 @@ from synapse.rest.admin._base import (
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -604,3 +607,65 @@ class ForwardExtremitiesRestServlet(RestServlet):
extremities = await self.store.get_forward_extremities_for_room(room_id)
return 200, {"count": len(extremities), "results": extremities}
+
+
+class RoomEventContextServlet(RestServlet):
+ """
+ Provide the context for an event.
+ This API is designed to be used when system administrators wish to look at
+ an abuse report and understand what happened during and immediately prior
+ to this event.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
+
+ def __init__(self, hs):
+ super().__init__()
+ self.clock = hs.get_clock()
+ self.room_context_handler = hs.get_room_context_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
+
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ limit = parse_integer(request, "limit", default=10)
+
+ # picking the API shape for symmetry with /messages
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
+ else:
+ event_filter = None
+
+ results = await self.room_context_handler.get_event_context(
+ requester,
+ room_id,
+ event_id,
+ limit,
+ event_filter,
+ use_admin_priviledge=True,
+ )
+
+ if not results:
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+ time_now = self.clock.time_msec()
+ results["events_before"] = await self._event_serializer.serialize_events(
+ results["events_before"], time_now
+ )
+ results["event"] = await self._event_serializer.serialize_event(
+ results["event"], time_now
+ )
+ results["events_after"] = await self._event_serializer.serialize_events(
+ results["events_after"], time_now
+ )
+ results["state"] = await self._event_serializer.serialize_events(
+ results["state"], time_now
+ )
+
+ return 200, results
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c8b128583f..b37f5aa873 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -648,7 +648,7 @@ class RoomEventContextServlet(RestServlet):
event_filter = None
results = await self.room_context_handler.get_event_context(
- requester.user, room_id, event_id, limit, event_filter
+ requester, room_id, event_id, limit, event_filter
)
if not results:
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..4fe712b30c 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,24 @@
import logging
from functools import wraps
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import GroupID, JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -33,7 +44,7 @@ def _validate_group_id(f):
"""
@wraps(f)
- def wrapper(self, request, group_id, *args, **kwargs):
+ def wrapper(self, request: Request, group_id: str, *args, **kwargs):
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
@@ -48,14 +59,14 @@ class GroupServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -66,11 +77,15 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
- async def on_POST(self, request, group_id):
+ async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert_params_in_dict(
+ content, ("name", "avatar_url", "short_description", "long_description")
+ )
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
await self.groups_handler.update_group_profile(
group_id, requester_user_id, content
)
@@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: str, room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_summary_room(
group_id,
requester_user_id,
@@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str, room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
@@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, category_id):
+ async def on_GET(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content
)
@@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id
)
@@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, role_id):
+ async def on_GET(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content
)
@@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id
)
@@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet):
"/users/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id, user_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_summary_user(
group_id,
requester_user_id,
@@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id, user_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str, user_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
@@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content
)
@@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet):
PATTERNS = client_patterns("/create_group$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.create_group(
group_id, requester_user_id, content
)
@@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
@@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
@_validate_group_id
- async def on_DELETE(self, request, group_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
@@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet):
"/config/(?P<config_key>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id, config_key):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str, config_key: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
@@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config
)
@@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
@@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content
)
@@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.join_group(
group_id, requester_user_id, content
)
@@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content
)
@@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/joined_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5d461efb6a..6b39e27f4c 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -195,6 +195,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
@@ -297,6 +298,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid", required=True)
client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
token = parse_string(request, "token", required=True)
# Attempt to validate a 3PID session
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 5ac307a62d..3e4566464b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,7 +58,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
+_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_xml_encoding_match = re.compile(
+ br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
@@ -299,24 +302,7 @@ class PreviewUrlResource(DirectServeJsonResource):
with open(media_info["filename"], "rb") as file:
body = file.read()
- encoding = None
-
- # Let's try and figure out if it has an encoding set in a meta tag.
- # Limit it to the first 1kb, since it ought to be in the meta tags
- # at the top.
- match = _charset_match.search(body[:1000])
-
- # If we find a match, it should take precedence over the
- # Content-Type header, so set it here.
- if match:
- encoding = match.group(1).decode("ascii")
-
- # If we don't find a match, we'll look at the HTTP Content-Type, and
- # if that doesn't exist, we'll fall back to UTF-8.
- if not encoding:
- content_match = _content_type_match.match(media_info["media_type"])
- encoding = content_match.group(1) if content_match else "utf-8"
-
+ encoding = get_html_media_encoding(body, media_info["media_type"])
og = decode_and_calc_og(body, media_info["uri"], encoding)
# pre-cache the image for posterity
@@ -688,6 +674,48 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
+def get_html_media_encoding(body: bytes, content_type: str) -> str:
+ """
+ Get the encoding of the body based on the (presumably) HTML body or media_type.
+
+ The precedence used for finding a character encoding is:
+
+ 1. meta tag with a charset declared.
+ 2. The XML document's character encoding attribute.
+ 3. The Content-Type header.
+ 4. Fallback to UTF-8.
+
+ Args:
+ body: The HTML document, as bytes.
+ content_type: The Content-Type header.
+
+ Returns:
+ The character encoding of the body, as a string.
+ """
+ # Limit searches to the first 1kb, since it ought to be at the top.
+ body_start = body[:1024]
+
+ # Let's try and figure out if it has an encoding set in a meta tag.
+ match = _charset_match.search(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+ # If we didn't find a match, see if it an XML document with an encoding.
+ match = _xml_encoding_match.match(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # If we don't find a match, we'll look at the HTTP Content-Type, and
+ # if that doesn't exist, we'll fall back to UTF-8.
+ content_match = _content_type_match.match(content_type)
+ if content_match:
+ return content_match.group(1)
+
+ return "utf-8"
+
+
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
@@ -724,6 +752,11 @@ def decode_and_calc_og(
def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
# Attempt to parse the body. If this fails, log and return no metadata.
tree = etree.fromstring(body_attempt, parser)
+
+ # The data was successfully parsed, but no tree was found.
+ if tree is None:
+ return {}
+
return _calc_og(tree, media_uri)
# Attempt to parse the body. If this fails, log and return no metadata.
diff --git a/synapse/server.py b/synapse/server.py
index 6ffb7e0fd9..91d59b755a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -24,7 +24,17 @@
import abc
import functools
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
import twisted.internet.base
import twisted.internet.tcp
@@ -582,7 +592,9 @@ class HomeServer(metaclass=abc.ABCMeta):
return UserDirectoryHandler(self)
@cache_in_self
- def get_groups_local_handler(self):
+ def get_groups_local_handler(
+ self,
+ ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
if self.config.worker_app:
return GroupsLocalWorkerHandler(self)
else:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d2ba4bd2fc..ae4bf1a54f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -158,8 +158,8 @@ class LoggingDatabaseConnection:
def commit(self) -> None:
self.conn.commit()
- def rollback(self, *args, **kwargs) -> None:
- self.conn.rollback(*args, **kwargs)
+ def rollback(self) -> None:
+ self.conn.rollback()
def __enter__(self) -> "Connection":
self.conn.__enter__()
@@ -244,12 +244,15 @@ class LoggingTransaction:
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
+ def fetchone(self) -> Optional[Tuple]:
+ return self.txn.fetchone()
+
+ def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+ return self.txn.fetchmany(size=size)
+
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
- def fetchone(self) -> Tuple:
- return self.txn.fetchone()
-
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@@ -754,6 +757,7 @@ class DatabasePool:
Returns:
A list of dicts where the key is the column header.
"""
+ assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 566ea19bae..28bb2eb662 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -619,9 +619,9 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
- current_version = int(row[0]) if row else None
- if current_version:
+ if row is not None:
+ current_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba18f..17291c9d5e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing_extensions import Protocol
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
+_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+
class Cursor(Protocol):
- def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
...
- def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...
- def fetchall(self) -> List[Tuple]:
+ def fetchone(self) -> Optional[Tuple]:
+ ...
+
+ def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
...
- def fetchone(self) -> Tuple:
+ def fetchall(self) -> List[Tuple]:
...
@property
- def description(self) -> Any:
- return None
+ def description(
+ self,
+ ) -> Optional[
+ Sequence[
+ # Note that this is an approximate typing based on sqlite3 and other
+ # drivers, and may not be entirely accurate.
+ Tuple[
+ str,
+ Optional[Any],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ ]
+ ]
+ ]:
+ ...
@property
def rowcount(self) -> int:
@@ -59,7 +80,7 @@ class Connection(Protocol):
def commit(self) -> None:
...
- def rollback(self, *args, **kwargs) -> None:
+ def rollback(self) -> None:
...
def __enter__(self) -> "Connection":
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 0ec4dc2918..e2b316a218 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
- return txn.fetchone()[0]
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ return fetch_res[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
@@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
- last_value, is_called = txn.fetchone()
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ last_value, is_called = fetch_res
# If we have an associated stream check the stream_positions table.
max_in_stream_positions = None
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f8038bf861..9ce7873ab5 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
@@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
rand = random.SystemRandom()
-def random_string(length):
+def random_string(length: int) -> str:
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
-def random_string_with_symbols(length):
+def random_string_with_symbols(length: int) -> str:
return "".join(rand.choice(_string_with_symbols) for _ in range(length))
-def is_ascii(s):
- if isinstance(s, bytes):
- try:
- s.decode("ascii").encode("ascii")
- except UnicodeDecodeError:
- return False
- except UnicodeEncodeError:
- return False
- return True
+def is_ascii(s: bytes) -> bool:
+ try:
+ s.decode("ascii").encode("ascii")
+ except UnicodeDecodeError:
+ return False
+ except UnicodeEncodeError:
+ return False
+ return True
-def assert_valid_client_secret(client_secret):
- """Validate that a given string matches the client_secret regex defined by the spec"""
- if client_secret_regex.match(client_secret) is None:
+def assert_valid_client_secret(client_secret: str) -> None:
+ """Validate that a given string matches the client_secret defined by the spec"""
+ if (
+ len(client_secret) <= 0
+ or len(client_secret) > 255
+ or CLIENT_SECRET_REGEX.match(client_secret) is None
+ ):
raise SynapseError(
400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 4a5df293a4..e39d02602a 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -80,6 +80,7 @@ async def filter_events_for_client(
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
+
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index 7c47aa7e0a..2a217b1ce0 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -1445,6 +1445,90 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
+ def test_context_as_non_admin(self):
+ """
+ Test that, without being admin, one cannot use the context admin API
+ """
+ # Create a room.
+ user_id = self.register_user("test", "test")
+ user_tok = self.login("test", "test")
+
+ self.register_user("test_2", "test")
+ user_tok_2 = self.login("test_2", "test")
+
+ room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+ # Populate the room with events.
+ events = []
+ for i in range(30):
+ events.append(
+ self.helper.send_event(
+ room_id, "com.example.test", content={"index": i}, tok=user_tok
+ )
+ )
+
+ # Now attempt to find the context using the admin API without being admin.
+ midway = (len(events) - 1) // 2
+ for tok in [user_tok, user_tok_2]:
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/context/%s"
+ % (room_id, events[midway]["event_id"]),
+ access_token=tok,
+ )
+ self.assertEquals(
+ 403, int(channel.result["code"]), msg=channel.result["body"]
+ )
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_context_as_admin(self):
+ """
+ Test that, as admin, we can find the context of an event without having joined the room.
+ """
+
+ # Create a room. We're not part of it.
+ user_id = self.register_user("test", "test")
+ user_tok = self.login("test", "test")
+ room_id = self.helper.create_room_as(user_id, tok=user_tok)
+
+ # Populate the room with events.
+ events = []
+ for i in range(30):
+ events.append(
+ self.helper.send_event(
+ room_id, "com.example.test", content={"index": i}, tok=user_tok
+ )
+ )
+
+ # Now let's fetch the context for this room.
+ midway = (len(events) - 1) // 2
+ channel = self.make_request(
+ "GET",
+ "/_synapse/admin/v1/rooms/%s/context/%s"
+ % (room_id, events[midway]["event_id"]),
+ access_token=self.admin_user_tok,
+ )
+ self.assertEquals(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEquals(
+ channel.json_body["event"]["event_id"], events[midway]["event_id"]
+ )
+
+ for i, found_event in enumerate(channel.json_body["events_before"]):
+ for j, posted_event in enumerate(events):
+ if found_event["event_id"] == posted_event["event_id"]:
+ self.assertTrue(j < midway)
+ break
+ else:
+ self.fail("Event %s from events_before not found" % j)
+
+ for i, found_event in enumerate(channel.json_body["events_after"]):
+ for j, posted_event in enumerate(events):
+ if found_event["event_id"] == posted_event["event_id"]:
+ self.assertTrue(j > midway)
+ break
+ else:
+ self.fail("Event %s from events_after not found" % j)
+
class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 0c6cbbd921..ea83299918 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -15,6 +15,7 @@
from synapse.rest.media.v1.preview_url_resource import (
decode_and_calc_og,
+ get_html_media_encoding,
summarize_paragraphs,
)
@@ -26,7 +27,7 @@ except ImportError:
lxml = None
-class PreviewTestCase(unittest.TestCase):
+class SummarizeTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
@@ -144,12 +145,12 @@ class PreviewTestCase(unittest.TestCase):
)
-class PreviewUrlTestCase(unittest.TestCase):
+class CalcOgTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
def test_simple(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -163,7 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -178,7 +179,7 @@ class PreviewUrlTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -202,7 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
)
def test_script(self):
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -217,7 +218,7 @@ class PreviewUrlTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self):
- html = """
+ html = b"""
<html>
<body>
Some text.
@@ -230,7 +231,7 @@ class PreviewUrlTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self):
- html = """
+ html = b"""
<html>
<meta property="og:description" content="Some text."/>
<body>
@@ -244,7 +245,7 @@ class PreviewUrlTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self):
- html = """
+ html = b"""
<html>
<body>
<h1><a href="foo"/></h1>
@@ -258,13 +259,20 @@ class PreviewUrlTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_empty(self):
- html = ""
+ """Test a body with no data in it."""
+ html = b""
+ og = decode_and_calc_og(html, "http://example.com/test.html")
+ self.assertEqual(og, {})
+
+ def test_no_tree(self):
+ """A valid body with no tree in it."""
+ html = b"\x00"
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {})
def test_invalid_encoding(self):
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
- html = """
+ html = b"""
<html>
<head><title>Foo</title></head>
<body>
@@ -290,3 +298,76 @@ class PreviewUrlTestCase(unittest.TestCase):
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
+
+
+class MediaEncodingTestCase(unittest.TestCase):
+ def test_meta_charset(self):
+ """A character encoding is found via the meta tag."""
+ encoding = get_html_media_encoding(
+ b"""
+ <html>
+ <head><meta charset="ascii">
+ </head>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "ascii")
+
+ # A less well-formed version.
+ encoding = get_html_media_encoding(
+ b"""
+ <html>
+ <head>< meta charset = ascii>
+ </head>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "ascii")
+
+ def test_xml_encoding(self):
+ """A character encoding is found via the meta tag."""
+ encoding = get_html_media_encoding(
+ b"""
+ <?xml version="1.0" encoding="ascii"?>
+ <html>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "ascii")
+
+ def test_meta_xml_encoding(self):
+ """Meta tags take precedence over XML encoding."""
+ encoding = get_html_media_encoding(
+ b"""
+ <?xml version="1.0" encoding="ascii"?>
+ <html>
+ <head><meta charset="UTF-16">
+ </head>
+ </html>
+ """,
+ "text/html",
+ )
+ self.assertEqual(encoding, "UTF-16")
+
+ def test_content_type(self):
+ """A character encoding is found via the Content-Type header."""
+ # Test a few variations of the header.
+ headers = (
+ 'text/html; charset="ascii";',
+ "text/html;charset=ascii;",
+ 'text/html; charset="ascii"',
+ "text/html; charset=ascii",
+ 'text/html; charset="ascii;',
+ 'text/html; charset=ascii";',
+ )
+ for header in headers:
+ encoding = get_html_media_encoding(b"", header)
+ self.assertEqual(encoding, "ascii")
+
+ def test_fallback(self):
+ """A character encoding cannot be found in the body or header."""
+ encoding = get_html_media_encoding(b"", "text/html")
+ self.assertEqual(encoding, "utf-8")
|