diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 0d042cbfac..76bf52ea23 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -18,6 +18,7 @@
import logging
from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
@@ -32,6 +33,11 @@ logger = logging.getLogger(__name__)
# TODO: Flairs
+# Note that the maximum lengths are somewhat arbitrary.
+MAX_SHORT_DESC_LEN = 1000
+MAX_LONG_DESC_LEN = 10000
+
+
class GroupsServerWorkerHandler:
def __init__(self, hs):
self.hs = hs
@@ -508,11 +514,26 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
)
profile = {}
- for keyname in ("name", "avatar_url", "short_description", "long_description"):
+ for keyname, max_length in (
+ ("name", MAX_DISPLAYNAME_LEN),
+ ("avatar_url", MAX_AVATAR_URL_LEN),
+ ("short_description", MAX_SHORT_DESC_LEN),
+ ("long_description", MAX_LONG_DESC_LEN),
+ ):
if keyname in content:
value = content[keyname]
if not isinstance(value, str):
- raise SynapseError(400, "%r value is not a string" % (keyname,))
+ raise SynapseError(
+ 400,
+ "%r value is not a string" % (keyname,),
+ errcode=Codes.INVALID_PARAM,
+ )
+ if len(value) > max_length:
+ raise SynapseError(
+ 400,
+ "Invalid %s parameter" % (keyname,),
+ errcode=Codes.INVALID_PARAM,
+ )
profile[keyname] = value
await self.store.update_group_profile(group_id, profile)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d037742081..736070d574 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -38,6 +38,7 @@ from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
+from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
@@ -1025,41 +1026,51 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
+ self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
async def get_event_context(
self,
- user: UserID,
+ requester: Requester,
room_id: str,
event_id: str,
limit: int,
event_filter: Optional[Filter],
+ use_admin_priviledge: bool = False,
) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
- user
+ requester
room_id
event_id
limit: The maximum number of events to return in total
(excluding state).
event_filter: the filter to apply to the events returned
(excluding the target event_id)
-
+ use_admin_priviledge: if `True`, return all events, regardless
+ of whether `user` has access to them. To be used **ONLY**
+ from the admin API.
Returns:
dict, or None if the event isn't found
"""
+ user = requester.user
+ if use_admin_priviledge:
+ await assert_user_is_admin(self.auth, requester.user)
+
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
- def filter_evts(events):
- return filter_events_for_client(
+ async def filter_evts(events):
+ if use_admin_priviledge:
+ return events
+ return await filter_events_for_client(
self.storage, user.to_string(), events, is_peeking=is_peeking
)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index f5c5d164f9..8457db1e22 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -42,6 +42,7 @@ from synapse.rest.admin.rooms import (
JoinRoomAliasServlet,
ListRoomRestServlet,
MakeRoomAdminRestServlet,
+ RoomEventContextServlet,
RoomMembersRestServlet,
RoomRestServlet,
RoomStateRestServlet,
@@ -238,6 +239,7 @@ def register_servlets(hs, http_server):
MakeRoomAdminRestServlet(hs).register(http_server)
ShadowBanRestServlet(hs).register(http_server)
ForwardExtremitiesRestServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index c7f5085470..acc8f9fa0a 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
+from urllib import parse as urlparse
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.filtering import Filter
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -34,6 +36,7 @@ from synapse.rest.admin._base import (
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -604,3 +607,65 @@ class ForwardExtremitiesRestServlet(RestServlet):
extremities = await self.store.get_forward_extremities_for_room(room_id)
return 200, {"count": len(extremities), "results": extremities}
+
+
+class RoomEventContextServlet(RestServlet):
+ """
+ Provide the context for an event.
+ This API is designed to be used when system administrators wish to look at
+ an abuse report and understand what happened during and immediately prior
+ to this event.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
+
+ def __init__(self, hs):
+ super().__init__()
+ self.clock = hs.get_clock()
+ self.room_context_handler = hs.get_room_context_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
+
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ limit = parse_integer(request, "limit", default=10)
+
+ # picking the API shape for symmetry with /messages
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
+ else:
+ event_filter = None
+
+ results = await self.room_context_handler.get_event_context(
+ requester,
+ room_id,
+ event_id,
+ limit,
+ event_filter,
+ use_admin_priviledge=True,
+ )
+
+ if not results:
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+ time_now = self.clock.time_msec()
+ results["events_before"] = await self._event_serializer.serialize_events(
+ results["events_before"], time_now
+ )
+ results["event"] = await self._event_serializer.serialize_event(
+ results["event"], time_now
+ )
+ results["events_after"] = await self._event_serializer.serialize_events(
+ results["events_after"], time_now
+ )
+ results["state"] = await self._event_serializer.serialize_events(
+ results["state"], time_now
+ )
+
+ return 200, results
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c8b128583f..b37f5aa873 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -648,7 +648,7 @@ class RoomEventContextServlet(RestServlet):
event_filter = None
results = await self.room_context_handler.get_event_context(
- requester.user, room_id, event_id, limit, event_filter
+ requester, room_id, event_id, limit, event_filter
)
if not results:
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..4fe712b30c 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,24 @@
import logging
from functools import wraps
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import GroupID, JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -33,7 +44,7 @@ def _validate_group_id(f):
"""
@wraps(f)
- def wrapper(self, request, group_id, *args, **kwargs):
+ def wrapper(self, request: Request, group_id: str, *args, **kwargs):
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
@@ -48,14 +59,14 @@ class GroupServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -66,11 +77,15 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
- async def on_POST(self, request, group_id):
+ async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert_params_in_dict(
+ content, ("name", "avatar_url", "short_description", "long_description")
+ )
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
await self.groups_handler.update_group_profile(
group_id, requester_user_id, content
)
@@ -84,14 +99,14 @@ class GroupSummaryServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -116,18 +131,21 @@ class GroupSummaryRoomsCatServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: str, room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_summary_room(
group_id,
requester_user_id,
@@ -139,10 +157,13 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str, room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
@@ -158,14 +179,16 @@ class GroupCategoryServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, category_id):
+ async def on_GET(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -176,11 +199,14 @@ class GroupCategoryServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content
)
@@ -188,10 +214,13 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id
)
@@ -205,14 +234,14 @@ class GroupCategoriesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -229,14 +258,16 @@ class GroupRoleServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, role_id):
+ async def on_GET(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -247,11 +278,14 @@ class GroupRoleServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content
)
@@ -259,10 +293,13 @@ class GroupRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id
)
@@ -276,14 +313,14 @@ class GroupRolesServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -308,18 +345,21 @@ class GroupSummaryUsersRoleServlet(RestServlet):
"/users/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id, user_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.update_group_summary_user(
group_id,
requester_user_id,
@@ -331,10 +371,13 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id, user_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str, user_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
@@ -348,14 +391,14 @@ class GroupRoomServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -372,14 +415,14 @@ class GroupUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -396,14 +439,14 @@ class GroupInvitedUsersServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -420,18 +463,19 @@ class GroupSettingJoinPolicyServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content
)
@@ -445,14 +489,14 @@ class GroupCreateServlet(RestServlet):
PATTERNS = client_patterns("/create_group$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -461,6 +505,7 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.create_group(
group_id, requester_user_id, content
)
@@ -476,18 +521,21 @@ class GroupAdminRoomsServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
@@ -495,10 +543,13 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
@_validate_group_id
- async def on_DELETE(self, request, group_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
@@ -515,18 +566,21 @@ class GroupAdminRoomsConfigServlet(RestServlet):
"/config/(?P<config_key>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id, config_key):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str, config_key: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
@@ -542,7 +596,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -551,12 +605,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config
)
@@ -572,18 +627,19 @@ class GroupAdminUsersKickServlet(RestServlet):
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
@@ -597,18 +653,19 @@ class GroupSelfLeaveServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content
)
@@ -622,18 +679,19 @@ class GroupSelfJoinServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.join_group(
group_id, requester_user_id, content
)
@@ -647,18 +705,19 @@ class GroupSelfAcceptInviteServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(self.groups_handler, GroupsLocalHandler)
result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content
)
@@ -672,14 +731,14 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -696,14 +755,14 @@ class PublicisedGroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -717,14 +776,14 @@ class PublicisedGroupsForUsersServlet(RestServlet):
PATTERNS = client_patterns("/publicised_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -741,13 +800,13 @@ class GroupsForUserServlet(RestServlet):
PATTERNS = client_patterns("/joined_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -756,7 +815,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5d461efb6a..6b39e27f4c 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -195,6 +195,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
@@ -297,6 +298,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid", required=True)
client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
token = parse_string(request, "token", required=True)
# Attempt to validate a 3PID session
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 5ac307a62d..3e4566464b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,7 +58,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
+_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_xml_encoding_match = re.compile(
+ br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
@@ -299,24 +302,7 @@ class PreviewUrlResource(DirectServeJsonResource):
with open(media_info["filename"], "rb") as file:
body = file.read()
- encoding = None
-
- # Let's try and figure out if it has an encoding set in a meta tag.
- # Limit it to the first 1kb, since it ought to be in the meta tags
- # at the top.
- match = _charset_match.search(body[:1000])
-
- # If we find a match, it should take precedence over the
- # Content-Type header, so set it here.
- if match:
- encoding = match.group(1).decode("ascii")
-
- # If we don't find a match, we'll look at the HTTP Content-Type, and
- # if that doesn't exist, we'll fall back to UTF-8.
- if not encoding:
- content_match = _content_type_match.match(media_info["media_type"])
- encoding = content_match.group(1) if content_match else "utf-8"
-
+ encoding = get_html_media_encoding(body, media_info["media_type"])
og = decode_and_calc_og(body, media_info["uri"], encoding)
# pre-cache the image for posterity
@@ -688,6 +674,48 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
+def get_html_media_encoding(body: bytes, content_type: str) -> str:
+ """
+ Get the encoding of the body based on the (presumably) HTML body or media_type.
+
+ The precedence used for finding a character encoding is:
+
+ 1. meta tag with a charset declared.
+ 2. The XML document's character encoding attribute.
+ 3. The Content-Type header.
+ 4. Fallback to UTF-8.
+
+ Args:
+ body: The HTML document, as bytes.
+ content_type: The Content-Type header.
+
+ Returns:
+ The character encoding of the body, as a string.
+ """
+ # Limit searches to the first 1kb, since it ought to be at the top.
+ body_start = body[:1024]
+
+ # Let's try and figure out if it has an encoding set in a meta tag.
+ match = _charset_match.search(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+ # If we didn't find a match, see if it an XML document with an encoding.
+ match = _xml_encoding_match.match(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # If we don't find a match, we'll look at the HTTP Content-Type, and
+ # if that doesn't exist, we'll fall back to UTF-8.
+ content_match = _content_type_match.match(content_type)
+ if content_match:
+ return content_match.group(1)
+
+ return "utf-8"
+
+
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
@@ -724,6 +752,11 @@ def decode_and_calc_og(
def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
# Attempt to parse the body. If this fails, log and return no metadata.
tree = etree.fromstring(body_attempt, parser)
+
+ # The data was successfully parsed, but no tree was found.
+ if tree is None:
+ return {}
+
return _calc_og(tree, media_uri)
# Attempt to parse the body. If this fails, log and return no metadata.
diff --git a/synapse/server.py b/synapse/server.py
index 6ffb7e0fd9..91d59b755a 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -24,7 +24,17 @@
import abc
import functools
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ TypeVar,
+ Union,
+ cast,
+)
import twisted.internet.base
import twisted.internet.tcp
@@ -582,7 +592,9 @@ class HomeServer(metaclass=abc.ABCMeta):
return UserDirectoryHandler(self)
@cache_in_self
- def get_groups_local_handler(self):
+ def get_groups_local_handler(
+ self,
+ ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]:
if self.config.worker_app:
return GroupsLocalWorkerHandler(self)
else:
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d2ba4bd2fc..ae4bf1a54f 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -158,8 +158,8 @@ class LoggingDatabaseConnection:
def commit(self) -> None:
self.conn.commit()
- def rollback(self, *args, **kwargs) -> None:
- self.conn.rollback(*args, **kwargs)
+ def rollback(self) -> None:
+ self.conn.rollback()
def __enter__(self) -> "Connection":
self.conn.__enter__()
@@ -244,12 +244,15 @@ class LoggingTransaction:
assert self.exception_callbacks is not None
self.exception_callbacks.append((callback, args, kwargs))
+ def fetchone(self) -> Optional[Tuple]:
+ return self.txn.fetchone()
+
+ def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
+ return self.txn.fetchmany(size=size)
+
def fetchall(self) -> List[Tuple]:
return self.txn.fetchall()
- def fetchone(self) -> Tuple:
- return self.txn.fetchone()
-
def __iter__(self) -> Iterator[Tuple]:
return self.txn.__iter__()
@@ -754,6 +757,7 @@ class DatabasePool:
Returns:
A list of dicts where the key is the column header.
"""
+ assert cursor.description is not None, "cursor.description was None"
col_headers = [intern(str(column[0])) for column in cursor.description]
results = [dict(zip(col_headers, row)) for row in cursor]
return results
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 566ea19bae..28bb2eb662 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -619,9 +619,9 @@ def _get_or_create_schema_state(
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
- current_version = int(row[0]) if row else None
- if current_version:
+ if row is not None:
+ current_version = int(row[0])
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,),
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 9cadcba18f..17291c9d5e 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterable, Iterator, List, Optional, Tuple
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from typing_extensions import Protocol
@@ -20,23 +20,44 @@ from typing_extensions import Protocol
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
+_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+
class Cursor(Protocol):
- def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
...
- def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
...
- def fetchall(self) -> List[Tuple]:
+ def fetchone(self) -> Optional[Tuple]:
+ ...
+
+ def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]:
...
- def fetchone(self) -> Tuple:
+ def fetchall(self) -> List[Tuple]:
...
@property
- def description(self) -> Any:
- return None
+ def description(
+ self,
+ ) -> Optional[
+ Sequence[
+ # Note that this is an approximate typing based on sqlite3 and other
+ # drivers, and may not be entirely accurate.
+ Tuple[
+ str,
+ Optional[Any],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ Optional[int],
+ ]
+ ]
+ ]:
+ ...
@property
def rowcount(self) -> int:
@@ -59,7 +80,7 @@ class Connection(Protocol):
def commit(self) -> None:
...
- def rollback(self, *args, **kwargs) -> None:
+ def rollback(self) -> None:
...
def __enter__(self) -> "Connection":
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 0ec4dc2918..e2b316a218 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
def get_next_id_txn(self, txn: Cursor) -> int:
txn.execute("SELECT nextval(?)", (self._sequence_name,))
- return txn.fetchone()[0]
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ return fetch_res[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
@@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute(
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
- last_value, is_called = txn.fetchone()
+ fetch_res = txn.fetchone()
+ assert fetch_res is not None
+ last_value, is_called = fetch_res
# If we have an associated stream check the stream_positions table.
max_in_stream_positions = None
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index f8038bf861..9ce7873ab5 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -25,7 +25,7 @@ from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
@@ -42,28 +42,31 @@ MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
rand = random.SystemRandom()
-def random_string(length):
+def random_string(length: int) -> str:
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
-def random_string_with_symbols(length):
+def random_string_with_symbols(length: int) -> str:
return "".join(rand.choice(_string_with_symbols) for _ in range(length))
-def is_ascii(s):
- if isinstance(s, bytes):
- try:
- s.decode("ascii").encode("ascii")
- except UnicodeDecodeError:
- return False
- except UnicodeEncodeError:
- return False
- return True
+def is_ascii(s: bytes) -> bool:
+ try:
+ s.decode("ascii").encode("ascii")
+ except UnicodeDecodeError:
+ return False
+ except UnicodeEncodeError:
+ return False
+ return True
-def assert_valid_client_secret(client_secret):
- """Validate that a given string matches the client_secret regex defined by the spec"""
- if client_secret_regex.match(client_secret) is None:
+def assert_valid_client_secret(client_secret: str) -> None:
+ """Validate that a given string matches the client_secret defined by the spec"""
+ if (
+ len(client_secret) <= 0
+ or len(client_secret) > 255
+ or CLIENT_SECRET_REGEX.match(client_secret) is None
+ ):
raise SynapseError(
400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 4a5df293a4..e39d02602a 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -80,6 +80,7 @@ async def filter_events_for_client(
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
+
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
|