summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-07-21 09:47:56 -0400
committerGitHub <noreply@github.com>2021-07-21 09:47:56 -0400
commit5db118626bebb9ce3913758282787d47cd8f375e (patch)
tree9515e33f8e3a319f2e76ca88094654d946304979
parentSwitch to `chunk` events so we can auth via power_levels (MSC2716) (#10432) (diff)
downloadsynapse-5db118626bebb9ce3913758282787d47cd8f375e.tar.xz
Add a return type to parse_string. (#10438)
And set the required attribute in a few places which will error if
a parameter is not provided.
-rw-r--r--changelog.d/10438.misc1
-rw-r--r--synapse/http/servlet.py38
-rw-r--r--synapse/rest/admin/users.py4
-rw-r--r--synapse/rest/client/v1/room.py8
-rw-r--r--synapse/rest/client/v2_alpha/keys.py2
-rw-r--r--synapse/rest/client/v2_alpha/relations.py42
-rw-r--r--synapse/rest/client/v2_alpha/sync.py2
-rw-r--r--synapse/rest/consent/consent_resource.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py10
-rw-r--r--synapse/storage/databases/main/__init__.py2
-rw-r--r--synapse/storage/databases/main/room.py2
-rw-r--r--synapse/storage/databases/main/stats.py2
-rw-r--r--synapse/streams/config.py16
13 files changed, 86 insertions, 45 deletions
diff --git a/changelog.d/10438.misc b/changelog.d/10438.misc
new file mode 100644
index 0000000000..a557578499
--- /dev/null
+++ b/changelog.d/10438.misc
@@ -0,0 +1 @@
+Improve servlet type hints.
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 04560fb589..cf45b6623b 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -172,6 +172,42 @@ def parse_bytes_from_args(
     return default
 
 
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    default: str,
+    *,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> str:
+    ...
+
+
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    *,
+    required: Literal[True],
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> str:
+    ...
+
+
+@overload
+def parse_string(
+    request: Request,
+    name: str,
+    *,
+    required: bool = False,
+    allowed_values: Optional[Iterable[str]] = None,
+    encoding: str = "ascii",
+) -> Optional[str]:
+    ...
+
+
 def parse_string(
     request: Request,
     name: str,
@@ -179,7 +215,7 @@ def parse_string(
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
-):
+) -> Optional[str]:
     """
     Parse a string parameter from the request query string.
 
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 589e47fa47..6736536172 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -90,8 +90,8 @@ class UsersRestServletV2(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        user_id = parse_string(request, "user_id", default=None)
-        name = parse_string(request, "name", default=None)
+        user_id = parse_string(request, "user_id")
+        name = parse_string(request, "name")
         guests = parse_boolean(request, "guests", default=True)
         deactivated = parse_boolean(request, "deactivated", default=False)
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index c95c5ae234..5d309a534c 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -413,7 +413,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
         assert_params_in_dict(body, ["state_events_at_start", "events"])
 
         prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
-        chunk_id_from_query = parse_string(request, "chunk_id", default=None)
+        chunk_id_from_query = parse_string(request, "chunk_id")
 
         if prev_events_from_query is None:
             raise SynapseError(
@@ -735,7 +735,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
         self.auth = hs.get_auth()
 
     async def on_GET(self, request):
-        server = parse_string(request, "server", default=None)
+        server = parse_string(request, "server")
 
         try:
             await self.auth.get_user_by_req(request, allow_guest=True)
@@ -755,7 +755,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
                 raise e
 
         limit = parse_integer(request, "limit", 0)
-        since_token = parse_string(request, "since", None)
+        since_token = parse_string(request, "since")
 
         if limit == 0:
             # zero is a special value which corresponds to no limit.
@@ -789,7 +789,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
     async def on_POST(self, request):
         await self.auth.get_user_by_req(request, allow_guest=True)
 
-        server = parse_string(request, "server", default=None)
+        server = parse_string(request, "server")
         content = parse_json_object_from_request(request)
 
         limit: Optional[int] = int(content.get("limit", 100))
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 33cf8de186..d0d9d30d40 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -194,7 +194,7 @@ class KeyChangesServlet(RestServlet):
     async def on_GET(self, request):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        from_token_string = parse_string(request, "from")
+        from_token_string = parse_string(request, "from", required=True)
         set_tag("from", from_token_string)
 
         # We want to enforce they do pass us one, but we ignore it and return
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index c7da6759db..0821cd285f 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -158,19 +158,21 @@ class RelationPaginationServlet(RestServlet):
         event = await self.event_handler.get_event(requester.user, room_id, parent_id)
 
         limit = parse_integer(request, "limit", default=5)
-        from_token = parse_string(request, "from")
-        to_token = parse_string(request, "to")
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
 
         if event.internal_metadata.is_redacted():
             # If the event is redacted, return an empty list of relations
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            if from_token:
-                from_token = RelationPaginationToken.from_string(from_token)
+            from_token = None
+            if from_token_str:
+                from_token = RelationPaginationToken.from_string(from_token_str)
 
-            if to_token:
-                to_token = RelationPaginationToken.from_string(to_token)
+            to_token = None
+            if to_token_str:
+                to_token = RelationPaginationToken.from_string(to_token_str)
 
             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
@@ -256,19 +258,21 @@ class RelationAggregationPaginationServlet(RestServlet):
             raise SynapseError(400, "Relation type must be 'annotation'")
 
         limit = parse_integer(request, "limit", default=5)
-        from_token = parse_string(request, "from")
-        to_token = parse_string(request, "to")
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
 
         if event.internal_metadata.is_redacted():
             # If the event is redacted, return an empty list of relations
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            if from_token:
-                from_token = AggregationPaginationToken.from_string(from_token)
+            from_token = None
+            if from_token_str:
+                from_token = AggregationPaginationToken.from_string(from_token_str)
 
-            if to_token:
-                to_token = AggregationPaginationToken.from_string(to_token)
+            to_token = None
+            if to_token_str:
+                to_token = AggregationPaginationToken.from_string(to_token_str)
 
             pagination_chunk = await self.store.get_aggregation_groups_for_event(
                 event_id=parent_id,
@@ -336,14 +340,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
             raise SynapseError(400, "Relation type must be 'annotation'")
 
         limit = parse_integer(request, "limit", default=5)
-        from_token = parse_string(request, "from")
-        to_token = parse_string(request, "to")
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
 
-        if from_token:
-            from_token = RelationPaginationToken.from_string(from_token)
+        from_token = None
+        if from_token_str:
+            from_token = RelationPaginationToken.from_string(from_token_str)
 
-        if to_token:
-            to_token = RelationPaginationToken.from_string(to_token)
+        to_token = None
+        if to_token_str:
+            to_token = RelationPaginationToken.from_string(to_token_str)
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index ecbbcf3851..7bb4e6b8aa 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -112,7 +112,7 @@ class SyncRestServlet(RestServlet):
             default="online",
             allowed_values=self.ALLOWED_PRESENCE,
         )
-        filter_id = parse_string(request, "filter", default=None)
+        filter_id = parse_string(request, "filter")
         full_state = parse_boolean(request, "full_state", default=False)
 
         logger.debug(
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4282e2b228..11f7320832 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -112,7 +112,7 @@ class ConsentResource(DirectServeHtmlResource):
             request (twisted.web.http.Request):
         """
         version = parse_string(request, "v", default=self._default_consent_version)
-        username = parse_string(request, "u", required=False, default="")
+        username = parse_string(request, "u", default="")
         userhmac = None
         has_consented = False
         public_version = username == ""
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8e7fead3a2..172212ee3a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -186,15 +186,11 @@ class PreviewUrlResource(DirectServeJsonResource):
         respond_with_json(request, 200, {}, send_cors=True)
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
-        # This will always be set by the time Twisted calls us.
-        assert request.args is not None
-
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)
-        url = parse_string(request, "url")
-        if b"ts" in request.args:
-            ts = parse_integer(request, "ts")
-        else:
+        url = parse_string(request, "url", required=True)
+        ts = parse_integer(request, "ts")
+        if ts is None:
             ts = self.clock.time_msec()
 
         # XXX: we could move this into _do_preview if we wanted.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a3fddea042..bacfbce4af 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -249,7 +249,7 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
-        order_by: UserSortOrder = UserSortOrder.USER_ID.value,
+        order_by: str = UserSortOrder.USER_ID.value,
         direction: str = "f",
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6ddafe5434..443e5f3315 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -363,7 +363,7 @@ class RoomWorkerStore(SQLBaseStore):
         self,
         start: int,
         limit: int,
-        order_by: RoomSortOrder,
+        order_by: str,
         reverse_order: bool,
         search_term: Optional[str],
     ) -> Tuple[List[Dict[str, Any]], int]:
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 59d67c255b..0f9aa54ca9 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -647,7 +647,7 @@ class StatsStore(StateDeltasStore):
         limit: int,
         from_ts: Optional[int] = None,
         until_ts: Optional[int] = None,
-        order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value,
+        order_by: Optional[str] = UserSortOrder.USER_ID.value,
         direction: Optional[str] = "f",
         search_term: Optional[str] = None,
     ) -> Tuple[List[JsonDict], Dict[str, int]]:
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 13d300588b..cf4005984b 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -47,20 +47,22 @@ class PaginationConfig:
     ) -> "PaginationConfig":
         direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
 
-        from_tok = parse_string(request, "from")
-        to_tok = parse_string(request, "to")
+        from_tok_str = parse_string(request, "from")
+        to_tok_str = parse_string(request, "to")
 
         try:
-            if from_tok == "END":
+            from_tok = None
+            if from_tok_str == "END":
                 from_tok = None  # For backwards compat.
-            elif from_tok:
-                from_tok = await StreamToken.from_string(store, from_tok)
+            elif from_tok_str:
+                from_tok = await StreamToken.from_string(store, from_tok_str)
         except Exception:
             raise SynapseError(400, "'from' parameter is invalid")
 
         try:
-            if to_tok:
-                to_tok = await StreamToken.from_string(store, to_tok)
+            to_tok = None
+            if to_tok_str:
+                to_tok = await StreamToken.from_string(store, to_tok_str)
         except Exception:
             raise SynapseError(400, "'to' parameter is invalid")