diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index c5c54564be..9b0c546505 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,11 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse
+from twisted.web.server import Request
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -30,6 +32,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
@@ -57,7 +60,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.txns = HttpTransactionCache(hs)
@@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
- def on_PUT(self, request, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
info, _ = await self._room_creation_handler.create_room(
@@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet):
return 200, info
- def get_room_config(self, request):
+ def get_room_config(self, request: Request) -> JsonDict:
user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
self.__class__.__name__,
)
- def on_GET_no_state_key(self, request, room_id, event_type):
+ def on_GET_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_GET(request, room_id, event_type, "")
- def on_PUT_no_state_key(self, request, room_id, event_type):
+ def on_PUT_no_state_key(
+ self, request: SynapseRequest, room_id: str, event_type: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_PUT(request, room_id, event_type, "")
- async def on_GET(self, request, room_id, event_type, state_key):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
@@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
- async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ # Format must be event or content, per the parse_string call above.
+ raise RuntimeError(f"Unknown format: {format:r}.")
+
+ async def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ state_key: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if txn_id:
@@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- async def on_POST(self, request, room_id, event_type, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- event_dict = {
+ event_dict: JsonDict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
+ # Twisted will have processed the args by now.
+ assert request.args is not None
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
@@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_GET(self, request, room_id, event_type, txn_id):
+ def on_GET(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Tuple[int, str]:
return 200, "Not implemented"
- def on_PUT(self, request, room_id, event_type, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /join/$room_identifier[/$txn_id]
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
@@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
- ):
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
- def on_PUT(self, request, room_identifier, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_identifier: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server")
try:
@@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
return 200, data
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server")
@@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet):
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
# TODO support Pagination stream API (limit/tokens)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
@@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet):
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
users_with_profile = await self.message_handler.get_joined_members(
@@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet):
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
)
+ # Twisted will have processed the args by now.
+ assert request.args is not None
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
@@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet):
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, List[JsonDict]]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
events = await self.message_handler.get_state_events(
@@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet):
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
@@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
event = await self.event_handler.get_event(
@@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
- event = await self._event_serializer.serialize_event(event, time_now)
- return 200, event
+ event_dict = await self._event_serializer.serialize_event(event, time_now)
+ return 200, event_dict
- return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
class RoomEventContextServlet(RestServlet):
@@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id, event_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str, event_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
- def on_PUT(self, request, room_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/[invite|join|leave]
PATTERNS = (
"/rooms/(?P<room_id>[^/]*)/"
@@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ membership_action: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
@@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
- def _has_3pid_invite_keys(self, content):
+ def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
for key in {"id_server", "medium", "address"}:
if key not in content:
return False
return True
- def on_PUT(self, request, room_id, membership_action, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(self, request, room_id, event_id, txn_id=None):
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_id: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_PUT(self, request, room_id, event_id, txn_id):
+ def on_PUT(
+ self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
+ ) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet):
hs.config.worker.writers.typing == hs.get_instance_name()
)
- async def on_PUT(self, request, room_id, user_id):
+ async def on_PUT(
+ self, request: SynapseRequest, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not self._is_typing_writer:
@@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet):
self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
alias_list = await self.directory_handler.get_aliases_for_room(
@@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
-def register_txn_path(servlet, regex_string, http_server, with_get=False):
+def register_txn_path(
+ servlet: RestServlet,
+ regex_string: str,
+ http_server: HttpServer,
+ with_get: bool = False,
+) -> None:
"""Registers a transaction-based path.
This registers two paths:
@@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
POST regex_string
Args:
- regex_string (str): The regex string to register. Must NOT have a
- trailing $ as this string will be appended to.
- http_server : The http_server to register paths with.
+ regex_string: The regex string to register. Must NOT have a
+ trailing $ as this string will be appended to.
+ http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
+ on_POST = getattr(servlet, "on_POST", None)
+ on_PUT = getattr(servlet, "on_PUT", None)
+ if on_POST is None or on_PUT is None:
+ raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
http_server.register_paths(
"POST",
client_patterns(regex_string + "$", v1=True),
- servlet.on_POST,
+ on_POST,
servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_PUT,
+ on_PUT,
servlet.__class__.__name__,
)
+ on_GET = getattr(servlet, "on_GET", None)
if with_get:
+ if on_GET is None:
+ raise RuntimeError(
+ "register_txn_path called with with_get = True, but no on_GET method exists"
+ )
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_GET,
+ on_GET,
servlet.__class__.__name__,
)
@@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
)
-def register_servlets(hs: "HomeServer", http_server, is_worker=False):
+def register_servlets(
+ hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
+) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomForgetRestServlet(hs).register(http_server)
-def register_deprecated_servlets(hs, http_server):
+def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomInitialSyncRestServlet(hs).register(http_server)
|