diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 0821cd285f..0b0711c03c 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -19,25 +19,32 @@ any time to reflect changes in the MSC.
"""
import logging
+from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import ShadowBanError, SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import (
AggregationPaginationToken,
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.stringutils import random_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet):
"/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs)
- def register(self, http_server):
+ def register(self, http_server: HttpServer) -> None:
http_server.register_paths(
"POST",
client_patterns(self.PATTERN + "$", releases=()),
@@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet):
self.__class__.__name__,
)
- def on_PUT(self, request, *args, **kwargs):
+ def on_PUT(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
- request, self.on_PUT_or_POST, request, *args, **kwargs
+ request,
+ self.on_PUT_or_POST,
+ request,
+ room_id,
+ parent_id,
+ relation_type,
+ event_type,
+ txn_id,
)
async def on_PUT_or_POST(
- self, request, room_id, parent_id, relation_type, event_type, txn_id=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ txn_id: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
@@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet):
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet):
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
@@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
async def on_GET(
- self, request, room_id, parent_id, relation_type=None, event_type=None
- ):
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to
# view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
releases=(),
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ parent_id: str,
+ relation_type: str,
+ event_type: str,
+ key: str,
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
@@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return 200, return_value
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server)
|