summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-02-01 16:35:24 -0500
committerGitHub <noreply@github.com>2023-02-01 21:35:24 +0000
commit1182ae50635db94d3c9c47990a0befcbf6306b62 (patch)
tree56bdbf809b884428af3f6fe28e18dc44472077ac
parentAttempt to delete more duplicate rows in receipts_linearized table. (#14915) (diff)
downloadsynapse-1182ae50635db94d3c9c47990a0befcbf6306b62.tar.xz
Add helper to parse an enum from query args & use it. (#14956)
The `parse_enum` helper pulls an enum value from the query string
(by delegating down to the parse_string helper with values generated
from the enum).

This is used to pull out "f" and "b" in most places and then we thread
the resulting Direction enum throughout more code.
-rw-r--r--changelog.d/14956.misc1
-rw-r--r--synapse/federation/federation_client.py15
-rw-r--r--synapse/federation/federation_server.py12
-rw-r--r--synapse/federation/transport/client.py8
-rw-r--r--synapse/federation/transport/server/federation.py7
-rw-r--r--synapse/handlers/account_data.py2
-rw-r--r--synapse/handlers/receipts.py2
-rw-r--r--synapse/handlers/room.py9
-rw-r--r--synapse/http/servlet.py70
-rw-r--r--synapse/rest/admin/event_reports.py12
-rw-r--r--synapse/rest/admin/federation.py7
-rw-r--r--synapse/rest/admin/media.py21
-rw-r--r--synapse/rest/admin/rooms.py16
-rw-r--r--synapse/rest/admin/statistics.py11
-rw-r--r--synapse/rest/admin/users.py5
-rw-r--r--synapse/rest/client/relations.py3
-rw-r--r--synapse/rest/client/room.py5
-rw-r--r--synapse/storage/databases/main/__init__.py5
-rw-r--r--synapse/storage/databases/main/events_worker.py11
-rw-r--r--synapse/storage/databases/main/media_repository.py5
-rw-r--r--synapse/storage/databases/main/room.py9
-rw-r--r--synapse/storage/databases/main/stats.py6
-rw-r--r--synapse/storage/databases/main/transactions.py13
-rw-r--r--synapse/streams/config.py12
-rw-r--r--tests/rest/admin/test_event_reports.py5
25 files changed, 176 insertions, 96 deletions
diff --git a/changelog.d/14956.misc b/changelog.d/14956.misc
new file mode 100644
index 0000000000..9f5384e60e
--- /dev/null
+++ b/changelog.d/14956.misc
@@ -0,0 +1 @@
+Add missing type hints.
\ No newline at end of file
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index feb32e40e5..8493ffc2e5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -37,7 +37,7 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
 from synapse.api.errors import (
     CodeMessageException,
     Codes,
@@ -1680,7 +1680,12 @@ class FederationClient(FederationBase):
         return result
 
     async def timestamp_to_event(
-        self, *, destinations: List[str], room_id: str, timestamp: int, direction: str
+        self,
+        *,
+        destinations: List[str],
+        room_id: str,
+        timestamp: int,
+        direction: Direction,
     ) -> Optional["TimestampToEventResponse"]:
         """
         Calls each remote federating server from `destinations` asking for their closest
@@ -1693,7 +1698,7 @@ class FederationClient(FederationBase):
             room_id: Room to fetch the event from
             timestamp: The point in time (inclusive) we should navigate from in
                 the given direction to find the closest event.
-            direction: ["f"|"b"] to indicate whether we should navigate forward
+            direction: indicates whether we should navigate forward
                 or backward from the given timestamp to find the closest event.
 
         Returns:
@@ -1738,7 +1743,7 @@ class FederationClient(FederationBase):
             return None
 
     async def _timestamp_to_event_from_destination(
-        self, destination: str, room_id: str, timestamp: int, direction: str
+        self, destination: str, room_id: str, timestamp: int, direction: Direction
     ) -> "TimestampToEventResponse":
         """
         Calls a remote federating server at `destination` asking for their
@@ -1751,7 +1756,7 @@ class FederationClient(FederationBase):
             room_id: Room to fetch the event from
             timestamp: The point in time (inclusive) we should navigate from in
                 the given direction to find the closest event.
-            direction: ["f"|"b"] to indicate whether we should navigate forward
+            direction: indicates whether we should navigate forward
                 or backward from the given timestamp to find the closest event.
 
         Returns:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c9a6dfd1a4..8d36172484 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,13 @@ from prometheus_client import Counter, Gauge, Histogram
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
-from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+    Direction,
+    EduTypes,
+    EventContentFields,
+    EventTypes,
+    Membership,
+)
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -218,7 +224,7 @@ class FederationServer(FederationBase):
         return 200, res
 
     async def on_timestamp_to_event_request(
-        self, origin: str, room_id: str, timestamp: int, direction: str
+        self, origin: str, room_id: str, timestamp: int, direction: Direction
     ) -> Tuple[int, Dict[str, Any]]:
         """When we receive a federated `/timestamp_to_event` request,
         handle all of the logic for validating and fetching the event.
@@ -228,7 +234,7 @@ class FederationServer(FederationBase):
             room_id: Room to fetch the event from
             timestamp: The point in time (inclusive) we should navigate from in
                 the given direction to find the closest event.
-            direction: ["f"|"b"] to indicate whether we should navigate forward
+            direction: indicates whether we should navigate forward
                 or backward from the given timestamp to find the closest event.
 
         Returns:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 682666ab36..c05d598b70 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -32,7 +32,7 @@ from typing import (
 import attr
 import ijson
 
-from synapse.api.constants import Membership
+from synapse.api.constants import Direction, Membership
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.api.urls import (
@@ -169,7 +169,7 @@ class TransportLayerClient:
         )
 
     async def timestamp_to_event(
-        self, destination: str, room_id: str, timestamp: int, direction: str
+        self, destination: str, room_id: str, timestamp: int, direction: Direction
     ) -> Union[JsonDict, List]:
         """
         Calls a remote federating server at `destination` asking for their
@@ -180,7 +180,7 @@ class TransportLayerClient:
             room_id: Room to fetch the event from
             timestamp: The point in time (inclusive) we should navigate from in
                 the given direction to find the closest event.
-            direction: ["f"|"b"] to indicate whether we should navigate forward
+            direction: indicates whether we should navigate forward
                 or backward from the given timestamp to find the closest event.
 
         Returns:
@@ -194,7 +194,7 @@ class TransportLayerClient:
             room_id,
         )
 
-        args = {"ts": [str(timestamp)], "dir": [direction]}
+        args = {"ts": [str(timestamp)], "dir": [direction.value]}
 
         remote_response = await self.client.get_json(
             destination, path=path, args=args, try_trailing_slash_on_400=True
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 17c427387e..f7ca87adc4 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -26,7 +26,7 @@ from typing import (
 
 from typing_extensions import Literal
 
-from synapse.api.constants import EduTypes
+from synapse.api.constants import Direction, EduTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
@@ -234,9 +234,10 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
         room_id: str,
     ) -> Tuple[int, JsonDict]:
         timestamp = parse_integer_from_args(query, "ts", required=True)
-        direction = parse_string_from_args(
-            query, "dir", default="f", allowed_values=["f", "b"], required=True
+        direction_str = parse_string_from_args(
+            query, "dir", allowed_values=["f", "b"], required=True
         )
+        direction = Direction(direction_str)
 
         return await self.handler.on_timestamp_to_event_request(
             origin, room_id, timestamp, direction
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index d500b21809..67e789eef7 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
 
-    def get_current_key(self, direction: str = "f") -> int:
+    def get_current_key(self) -> int:
         return self.store.get_max_account_data_stream_id()
 
     async def get_new_events(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6a4fed1156..04c61ae3dd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -315,5 +315,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
         return events, to_key
 
-    def get_current_key(self, direction: str = "f") -> int:
+    def get_current_key(self) -> int:
         return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 60a6d9cf3c..7ba7c4ff07 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,6 +27,7 @@ from typing_extensions import TypedDict
 
 import synapse.events.snapshot
 from synapse.api.constants import (
+    Direction,
     EventContentFields,
     EventTypes,
     GuestAccess,
@@ -1487,7 +1488,7 @@ class TimestampLookupHandler:
         requester: Requester,
         room_id: str,
         timestamp: int,
-        direction: str,
+        direction: Direction,
     ) -> Tuple[str, int]:
         """Find the closest event to the given timestamp in the given direction.
         If we can't find an event locally or the event we have locally is next to a gap,
@@ -1498,7 +1499,7 @@ class TimestampLookupHandler:
             room_id: Room to fetch the event from
             timestamp: The point in time (inclusive) we should navigate from in
                 the given direction to find the closest event.
-            direction: ["f"|"b"] to indicate whether we should navigate forward
+            direction: indicates whether we should navigate forward
                 or backward from the given timestamp to find the closest event.
 
         Returns:
@@ -1533,13 +1534,13 @@ class TimestampLookupHandler:
                 local_event_id, allow_none=False, allow_rejected=False
             )
 
-            if direction == "f":
+            if direction == Direction.FORWARDS:
                 # We only need to check for a backward gap if we're looking forwards
                 # to ensure there is nothing in between.
                 is_event_next_to_backward_gap = (
                     await self.store.is_event_next_to_backward_gap(local_event)
                 )
-            elif direction == "b":
+            elif direction == Direction.BACKWARDS:
                 # We only need to check for a forward gap if we're looking backwards
                 # to ensure there is nothing in between
                 is_event_next_to_forward_gap = (
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..0070bd2940 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 """ This module contains base REST classes for constructing REST servlets. """
+import enum
 import logging
 from http import HTTPStatus
 from typing import (
@@ -362,6 +363,7 @@ def parse_string(
     request: Request,
     name: str,
     *,
+    default: Optional[str] = None,
     required: bool = False,
     allowed_values: Optional[Iterable[str]] = None,
     encoding: str = "ascii",
@@ -413,6 +415,74 @@ def parse_string(
     )
 
 
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: EnumT,
+) -> EnumT:
+    ...
+
+
+@overload
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    *,
+    required: Literal[True],
+) -> EnumT:
+    ...
+
+
+def parse_enum(
+    request: Request,
+    name: str,
+    E: Type[EnumT],
+    default: Optional[EnumT] = None,
+    required: bool = False,
+) -> Optional[EnumT]:
+    """
+    Parse an enum parameter from the request query string.
+
+    Note that the enum *must only have string values*.
+
+    Args:
+        request: the twisted HTTP request.
+        name: the name of the query parameter.
+        E: the enum which represents valid values
+        default: enum value to use if the parameter is absent, defaults to None.
+        required: whether to raise a 400 SynapseError if the
+            parameter is absent, defaults to False.
+
+    Returns:
+        An enum value.
+
+    Raises:
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+    # Assert the enum values are strings.
+    assert all(
+        isinstance(e.value, str) for e in E
+    ), "parse_enum only works with string values"
+    str_value = parse_string(
+        request,
+        name,
+        default=default.value if default is not None else None,
+        required=required,
+        allowed_values=[e.value for e in E],
+    )
+    if str_value is None:
+        return None
+    return E(str_value)
+
+
 def _parse_string_value(
     value: bytes,
     allowed_values: Optional[Iterable[str]],
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 6d634eef70..a3beb74e2c 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -16,8 +16,9 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.types import JsonDict
@@ -60,7 +61,7 @@ class EventReportsRestServlet(RestServlet):
 
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
-        direction = parse_string(request, "dir", default="b")
+        direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
         user_id = parse_string(request, "user_id")
         room_id = parse_string(request, "room_id")
 
@@ -78,13 +79,6 @@ class EventReportsRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        if direction not in ("f", "b"):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Unknown direction: %s" % (direction,),
-                errcode=Codes.INVALID_PARAM,
-            )
-
         event_reports, total = await self.store.get_event_reports_paginate(
             start, limit, direction, user_id, room_id
         )
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 023ed92144..e0ee55bd0e 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -15,9 +15,10 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.federation.transport.server import Authenticator
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.storage.databases.main.transactions import DestinationSortOrder
@@ -79,7 +80,7 @@ class ListDestinationsRestServlet(RestServlet):
             allowed_values=[dest.value for dest in DestinationSortOrder],
         )
 
-        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+        direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
         destinations, total = await self._store.get_destinations_paginate(
             start, limit, destination, order_by, direction
@@ -192,7 +193,7 @@ class DestinationMembershipRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+        direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
         rooms, total = await self._store.get_destination_rooms_paginate(
             destination, start, limit, direction
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 73470f09ae..0d072c42a7 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,9 +17,16 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.api.constants import Direction
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
+from synapse.http.servlet import (
+    RestServlet,
+    parse_boolean,
+    parse_enum,
+    parse_integer,
+    parse_string,
+)
 from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import (
     admin_patterns,
@@ -389,7 +396,7 @@ class UserMediaRestServlet(RestServlet):
         # to newest media is on top for backward compatibility.
         if b"order_by" not in request.args and b"dir" not in request.args:
             order_by = MediaSortOrder.CREATED_TS.value
-            direction = "b"
+            direction = Direction.BACKWARDS
         else:
             order_by = parse_string(
                 request,
@@ -397,8 +404,8 @@ class UserMediaRestServlet(RestServlet):
                 default=MediaSortOrder.CREATED_TS.value,
                 allowed_values=[sort_order.value for sort_order in MediaSortOrder],
             )
-            direction = parse_string(
-                request, "dir", default="f", allowed_values=("f", "b")
+            direction = parse_enum(
+                request, "dir", Direction, default=Direction.FORWARDS
             )
 
         media, total = await self.store.get_local_media_by_user_paginate(
@@ -447,7 +454,7 @@ class UserMediaRestServlet(RestServlet):
         # to newest media is on top for backward compatibility.
         if b"order_by" not in request.args and b"dir" not in request.args:
             order_by = MediaSortOrder.CREATED_TS.value
-            direction = "b"
+            direction = Direction.BACKWARDS
         else:
             order_by = parse_string(
                 request,
@@ -455,8 +462,8 @@ class UserMediaRestServlet(RestServlet):
                 default=MediaSortOrder.CREATED_TS.value,
                 allowed_values=[sort_order.value for sort_order in MediaSortOrder],
             )
-            direction = parse_string(
-                request, "dir", default="f", allowed_values=("f", "b")
+            direction = parse_enum(
+                request, "dir", Direction, default=Direction.FORWARDS
             )
 
         media, _ = await self.store.get_local_media_by_user_paginate(
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index e957aa28ca..1d6e4982d7 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,13 +16,14 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 from urllib import parse as urlparse
 
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.http.servlet import (
     ResolveRoomIdMixin,
     RestServlet,
     assert_params_in_dict,
+    parse_enum,
     parse_integer,
     parse_json_object_from_request,
     parse_string,
@@ -224,15 +225,8 @@ class ListRoomRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        direction = parse_string(request, "dir", default="f")
-        if direction not in ("f", "b"):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Unknown direction: %s" % (direction,),
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        reverse_order = True if direction == "b" else False
+        direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
+        reverse_order = True if direction == Direction.BACKWARDS else False
 
         # Return list of rooms according to parameters
         rooms, total_rooms = await self.store.get_rooms_paginate(
@@ -949,7 +943,7 @@ class RoomTimestampToEventRestServlet(RestServlet):
         await assert_user_is_admin(self._auth, requester)
 
         timestamp = parse_integer(request, "ts", required=True)
-        direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+        direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
         (
             event_id,
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 3b142b8402..9c45f4650d 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -16,8 +16,9 @@ import logging
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Tuple
 
+from synapse.api.constants import Direction
 from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
 from synapse.storage.databases.main.stats import UserSortOrder
@@ -102,13 +103,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
                 errcode=Codes.INVALID_PARAM,
             )
 
-        direction = parse_string(request, "dir", default="f")
-        if direction not in ("f", "b"):
-            raise SynapseError(
-                HTTPStatus.BAD_REQUEST,
-                "Unknown direction: %s" % (direction,),
-                errcode=Codes.INVALID_PARAM,
-            )
+        direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
         users_media, total = await self.store.get_users_media_usage_paginate(
             start, limit, from_ts, until_ts, order_by, direction, search_term
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 0841b89c1a..b9dca8ef3a 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -18,12 +18,13 @@ import secrets
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from synapse.api.constants import UserTypes
+from synapse.api.constants import Direction, UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_boolean,
+    parse_enum,
     parse_integer,
     parse_json_object_from_request,
     parse_string,
@@ -120,7 +121,7 @@ class UsersRestServletV2(RestServlet):
             ),
         )
 
-        direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+        direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
         users, total = await self.store.get_users_paginate(
             start,
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 9dd59196d9..7456d6f507 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -16,6 +16,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Optional, Tuple
 
+from synapse.api.constants import Direction
 from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -59,7 +60,7 @@ class RelationPaginationServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         pagination_config = await PaginationConfig.from_request(
-            self._store, request, default_limit=5, default_dir="b"
+            self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
         )
 
         # The unstable version of this API returns an extra field for client
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 790614d721..d0db85cca7 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -26,7 +26,7 @@ from prometheus_client.core import Histogram
 from twisted.web.server import Request
 
 from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -44,6 +44,7 @@ from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
     parse_boolean,
+    parse_enum,
     parse_integer,
     parse_json_object_from_request,
     parse_string,
@@ -1297,7 +1298,7 @@ class TimestampLookupRestServlet(RestServlet):
         await self._auth.check_user_in_room_or_world_readable(room_id, requester)
 
         timestamp = parse_integer(request, "ts", required=True)
-        direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+        direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
 
         (
             event_id,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0e47592be3..837dc7646e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,7 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
+from synapse.api.constants import Direction
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import (
     DatabasePool,
@@ -167,7 +168,7 @@ class DataStore(
         guests: bool = True,
         deactivated: bool = False,
         order_by: str = UserSortOrder.NAME.value,
-        direction: str = "f",
+        direction: Direction = Direction.FORWARDS,
         approved: bool = True,
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
@@ -197,7 +198,7 @@ class DataStore(
             # Set ordering
             order_by_column = UserSortOrder(order_by).value
 
-            if direction == "b":
+            if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
                 order = "ASC"
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f42af34a2f..d7d08369ca 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -38,7 +38,7 @@ from typing_extensions import Literal
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import Direction, EventTypes
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
@@ -2240,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore):
         )
 
     async def get_event_id_for_timestamp(
-        self, room_id: str, timestamp: int, direction: str
+        self, room_id: str, timestamp: int, direction: Direction
     ) -> Optional[str]:
         """Find the closest event to the given timestamp in the given direction.
 
@@ -2248,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore):
             room_id: Room to fetch the event from
             timestamp: The point in time (inclusive) we should navigate from in
                 the given direction to find the closest event.
-            direction: ["f"|"b"] to indicate whether we should navigate forward
+            direction: indicates whether we should navigate forward
                 or backward from the given timestamp to find the closest event.
 
         Returns:
             The closest event_id otherwise None if we can't find any event in
             the given direction.
         """
-        if direction == "b":
+        if direction == Direction.BACKWARDS:
             # Find closest event *before* a given timestamp. We use descending
             # (which gives values largest to smallest) because we want the
             # largest possible timestamp *before* the given timestamp.
@@ -2307,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore):
 
             return None
 
-        if direction not in ("f", "b"):
-            raise ValueError("Unknown direction: %s" % (direction,))
-
         return await self.db_pool.runInteraction(
             "get_event_id_for_timestamp_txn",
             get_event_id_for_timestamp_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 9b172a64d8..b202c5eb87 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,7 @@ from typing import (
     cast,
 )
 
+from synapse.api.constants import Direction
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         limit: int,
         user_id: str,
         order_by: str = MediaSortOrder.CREATED_TS.value,
-        direction: str = "f",
+        direction: Direction = Direction.FORWARDS,
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media
         which an user_id has uploaded
@@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
 
-            if direction == "b":
+            if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
                 order = "ASC"
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index fbbc018887..4ddb27f686 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -35,6 +35,7 @@ from typing import (
 import attr
 
 from synapse.api.constants import (
+    Direction,
     EventContentFields,
     EventTypes,
     JoinRules,
@@ -2204,7 +2205,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         self,
         start: int,
         limit: int,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         user_id: Optional[str] = None,
         room_id: Optional[str] = None,
     ) -> Tuple[List[Dict[str, Any]], int]:
@@ -2213,8 +2214,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         Args:
             start: event offset to begin the query from
             limit: number of rows to retrieve
-            direction: Whether to fetch the most recent first (`"b"`) or the
-                oldest first (`"f"`)
+            direction: Whether to fetch the most recent first (backwards) or the
+                oldest first (forwards)
             user_id: search for user_id. Ignored if user_id is None
             room_id: search for room_id. Ignored if room_id is None
         Returns:
@@ -2236,7 +2237,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
                 filters.append("er.room_id LIKE ?")
                 args.extend(["%" + room_id + "%"])
 
-            if direction == "b":
+            if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
                 order = "ASC"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0c1cbd540d..d7b7d0c3c9 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -22,7 +22,7 @@ from typing_extensions import Counter
 
 from twisted.internet.defer import DeferredLock
 
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
 from synapse.api.errors import StoreError
 from synapse.storage.database import (
     DatabasePool,
@@ -663,7 +663,7 @@ class StatsStore(StateDeltasStore):
         from_ts: Optional[int] = None,
         until_ts: Optional[int] = None,
         order_by: Optional[str] = UserSortOrder.USER_ID.value,
-        direction: Optional[str] = "f",
+        direction: Direction = Direction.FORWARDS,
         search_term: Optional[str] = None,
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users and their uploaded local media
@@ -714,7 +714,7 @@ class StatsStore(StateDeltasStore):
                     500, "Incorrect value for order_by provided: %s" % order_by
                 )
 
-            if direction == "b":
+            if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
                 order = "ASC"
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index f8c6877ee8..6b33d809b6 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
 import attr
 from canonicaljson import encode_canonical_json
 
+from synapse.api.constants import Direction
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import db_to_json
 from synapse.storage.database import (
@@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         limit: int,
         destination: Optional[str] = None,
         order_by: str = DestinationSortOrder.DESTINATION.value,
-        direction: str = "f",
+        direction: Direction = Direction.FORWARDS,
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of destinations.
         This will return a json list of destinations and the
@@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         ) -> Tuple[List[JsonDict], int]:
             order_by_column = DestinationSortOrder(order_by).value
 
-            if direction == "b":
+            if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
                 order = "ASC"
@@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         )
 
     async def get_destination_rooms_paginate(
-        self, destination: str, start: int, limit: int, direction: str = "f"
+        self,
+        destination: str,
+        start: int,
+        limit: int,
+        direction: Direction = Direction.FORWARDS,
     ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of destination's rooms.
         This will return a json list of rooms and the
@@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             txn: LoggingTransaction,
         ) -> Tuple[List[JsonDict], int]:
 
-            if direction == "b":
+            if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
                 order = "ASC"
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 5cb7875181..a044280410 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -18,7 +18,7 @@ import attr
 
 from synapse.api.constants import Direction
 from synapse.api.errors import SynapseError
-from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.servlet import parse_enum, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.storage.databases.main import DataStore
 from synapse.types import StreamToken
@@ -44,15 +44,9 @@ class PaginationConfig:
         store: "DataStore",
         request: SynapseRequest,
         default_limit: int,
-        default_dir: str = "f",
+        default_dir: Direction = Direction.FORWARDS,
     ) -> "PaginationConfig":
-        direction_str = parse_string(
-            request,
-            "dir",
-            default=default_dir,
-            allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
-        )
-        direction = Direction(direction_str)
+        direction = parse_enum(request, "dir", Direction, default=default_dir)
 
         from_tok_str = parse_string(request, "from")
         to_tok_str = parse_string(request, "to")
diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py
index 8a4e5c3f77..233eba3516 100644
--- a/tests/rest/admin/test_event_reports.py
+++ b/tests/rest/admin/test_event_reports.py
@@ -280,7 +280,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
-        self.assertEqual("Unknown direction: bar", channel.json_body["error"])
+        self.assertEqual(
+            "Query parameter 'dir' must be one of ['b', 'f']",
+            channel.json_body["error"],
+        )
 
     def test_limit_is_negative(self) -> None:
         """