diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index feb32e40e5..8493ffc2e5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -37,7 +37,7 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
Codes,
@@ -1680,7 +1680,12 @@ class FederationClient(FederationBase):
return result
async def timestamp_to_event(
- self, *, destinations: List[str], room_id: str, timestamp: int, direction: str
+ self,
+ *,
+ destinations: List[str],
+ room_id: str,
+ timestamp: int,
+ direction: Direction,
) -> Optional["TimestampToEventResponse"]:
"""
Calls each remote federating server from `destinations` asking for their closest
@@ -1693,7 +1698,7 @@ class FederationClient(FederationBase):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
@@ -1738,7 +1743,7 @@ class FederationClient(FederationBase):
return None
async def _timestamp_to_event_from_destination(
- self, destination: str, room_id: str, timestamp: int, direction: str
+ self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> "TimestampToEventResponse":
"""
Calls a remote federating server at `destination` asking for their
@@ -1751,7 +1756,7 @@ class FederationClient(FederationBase):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c9a6dfd1a4..8d36172484 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,13 @@ from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+ Direction,
+ EduTypes,
+ EventContentFields,
+ EventTypes,
+ Membership,
+)
from synapse.api.errors import (
AuthError,
Codes,
@@ -218,7 +224,7 @@ class FederationServer(FederationBase):
return 200, res
async def on_timestamp_to_event_request(
- self, origin: str, room_id: str, timestamp: int, direction: str
+ self, origin: str, room_id: str, timestamp: int, direction: Direction
) -> Tuple[int, Dict[str, Any]]:
"""When we receive a federated `/timestamp_to_event` request,
handle all of the logic for validating and fetching the event.
@@ -228,7 +234,7 @@ class FederationServer(FederationBase):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 682666ab36..c05d598b70 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -32,7 +32,7 @@ from typing import (
import attr
import ijson
-from synapse.api.constants import Membership
+from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
@@ -169,7 +169,7 @@ class TransportLayerClient:
)
async def timestamp_to_event(
- self, destination: str, room_id: str, timestamp: int, direction: str
+ self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, List]:
"""
Calls a remote federating server at `destination` asking for their
@@ -180,7 +180,7 @@ class TransportLayerClient:
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
@@ -194,7 +194,7 @@ class TransportLayerClient:
room_id,
)
- args = {"ts": [str(timestamp)], "dir": [direction]}
+ args = {"ts": [str(timestamp)], "dir": [direction.value]}
remote_response = await self.client.get_json(
destination, path=path, args=args, try_trailing_slash_on_400=True
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 17c427387e..f7ca87adc4 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -26,7 +26,7 @@ from typing import (
from typing_extensions import Literal
-from synapse.api.constants import EduTypes
+from synapse.api.constants import Direction, EduTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
@@ -234,9 +234,10 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
room_id: str,
) -> Tuple[int, JsonDict]:
timestamp = parse_integer_from_args(query, "ts", required=True)
- direction = parse_string_from_args(
- query, "dir", default="f", allowed_values=["f", "b"], required=True
+ direction_str = parse_string_from_args(
+ query, "dir", allowed_values=["f", "b"], required=True
)
+ direction = Direction(direction_str)
return await self.handler.on_timestamp_to_event_request(
origin, room_id, timestamp, direction
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index d500b21809..67e789eef7 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
- def get_current_key(self, direction: str = "f") -> int:
+ def get_current_key(self) -> int:
return self.store.get_max_account_data_stream_id()
async def get_new_events(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 6a4fed1156..04c61ae3dd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -315,5 +315,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
return events, to_key
- def get_current_key(self, direction: str = "f") -> int:
+ def get_current_key(self) -> int:
return self.store.get_max_receipt_stream_id()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 60a6d9cf3c..7ba7c4ff07 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,6 +27,7 @@ from typing_extensions import TypedDict
import synapse.events.snapshot
from synapse.api.constants import (
+ Direction,
EventContentFields,
EventTypes,
GuestAccess,
@@ -1487,7 +1488,7 @@ class TimestampLookupHandler:
requester: Requester,
room_id: str,
timestamp: int,
- direction: str,
+ direction: Direction,
) -> Tuple[str, int]:
"""Find the closest event to the given timestamp in the given direction.
If we can't find an event locally or the event we have locally is next to a gap,
@@ -1498,7 +1499,7 @@ class TimestampLookupHandler:
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
@@ -1533,13 +1534,13 @@ class TimestampLookupHandler:
local_event_id, allow_none=False, allow_rejected=False
)
- if direction == "f":
+ if direction == Direction.FORWARDS:
# We only need to check for a backward gap if we're looking forwards
# to ensure there is nothing in between.
is_event_next_to_backward_gap = (
await self.store.is_event_next_to_backward_gap(local_event)
)
- elif direction == "b":
+ elif direction == Direction.BACKWARDS:
# We only need to check for a forward gap if we're looking backwards
# to ensure there is nothing in between
is_event_next_to_forward_gap = (
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index dead02cd5c..0070bd2940 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -13,6 +13,7 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
+import enum
import logging
from http import HTTPStatus
from typing import (
@@ -362,6 +363,7 @@ def parse_string(
request: Request,
name: str,
*,
+ default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
@@ -413,6 +415,74 @@ def parse_string(
)
+EnumT = TypeVar("EnumT", bound=enum.Enum)
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: EnumT,
+) -> EnumT:
+ ...
+
+
+@overload
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ *,
+ required: Literal[True],
+) -> EnumT:
+ ...
+
+
+def parse_enum(
+ request: Request,
+ name: str,
+ E: Type[EnumT],
+ default: Optional[EnumT] = None,
+ required: bool = False,
+) -> Optional[EnumT]:
+ """
+ Parse an enum parameter from the request query string.
+
+ Note that the enum *must only have string values*.
+
+ Args:
+ request: the twisted HTTP request.
+ name: the name of the query parameter.
+ E: the enum which represents valid values
+ default: enum value to use if the parameter is absent, defaults to None.
+ required: whether to raise a 400 SynapseError if the
+ parameter is absent, defaults to False.
+
+ Returns:
+ An enum value.
+
+ Raises:
+ SynapseError if the parameter is absent and required, or if the
+ parameter is present, must be one of a list of allowed values and
+ is not one of those allowed values.
+ """
+ # Assert the enum values are strings.
+ assert all(
+ isinstance(e.value, str) for e in E
+ ), "parse_enum only works with string values"
+ str_value = parse_string(
+ request,
+ name,
+ default=default.value if default is not None else None,
+ required=required,
+ allowed_values=[e.value for e in E],
+ )
+ if str_value is None:
+ return None
+ return E(str_value)
+
+
def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],
diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py
index 6d634eef70..a3beb74e2c 100644
--- a/synapse/rest/admin/event_reports.py
+++ b/synapse/rest/admin/event_reports.py
@@ -16,8 +16,9 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
@@ -60,7 +61,7 @@ class EventReportsRestServlet(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
- direction = parse_string(request, "dir", default="b")
+ direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
user_id = parse_string(request, "user_id")
room_id = parse_string(request, "room_id")
@@ -78,13 +79,6 @@ class EventReportsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- if direction not in ("f", "b"):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown direction: %s" % (direction,),
- errcode=Codes.INVALID_PARAM,
- )
-
event_reports, total = await self.store.get_event_reports_paginate(
start, limit, direction, user_id, room_id
)
diff --git a/synapse/rest/admin/federation.py b/synapse/rest/admin/federation.py
index 023ed92144..e0ee55bd0e 100644
--- a/synapse/rest/admin/federation.py
+++ b/synapse/rest/admin/federation.py
@@ -15,9 +15,10 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.transactions import DestinationSortOrder
@@ -79,7 +80,7 @@ class ListDestinationsRestServlet(RestServlet):
allowed_values=[dest.value for dest in DestinationSortOrder],
)
- direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
@@ -192,7 +193,7 @@ class DestinationMembershipRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 73470f09ae..0d072c42a7 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -17,9 +17,16 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
+from synapse.http.servlet import (
+ RestServlet,
+ parse_boolean,
+ parse_enum,
+ parse_integer,
+ parse_string,
+)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
@@ -389,7 +396,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
- direction = "b"
+ direction = Direction.BACKWARDS
else:
order_by = parse_string(
request,
@@ -397,8 +404,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
- direction = parse_string(
- request, "dir", default="f", allowed_values=("f", "b")
+ direction = parse_enum(
+ request, "dir", Direction, default=Direction.FORWARDS
)
media, total = await self.store.get_local_media_by_user_paginate(
@@ -447,7 +454,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
- direction = "b"
+ direction = Direction.BACKWARDS
else:
order_by = parse_string(
request,
@@ -455,8 +462,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
- direction = parse_string(
- request, "dir", default="f", allowed_values=("f", "b")
+ direction = parse_enum(
+ request, "dir", Direction, default=Direction.FORWARDS
)
media, _ = await self.store.get_local_media_by_user_paginate(
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index e957aa28ca..1d6e4982d7 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -16,13 +16,14 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
assert_params_in_dict,
+ parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -224,15 +225,8 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- direction = parse_string(request, "dir", default="f")
- if direction not in ("f", "b"):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown direction: %s" % (direction,),
- errcode=Codes.INVALID_PARAM,
- )
-
- reverse_order = True if direction == "b" else False
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
+ reverse_order = True if direction == Direction.BACKWARDS else False
# Return list of rooms according to parameters
rooms, total_rooms = await self.store.get_rooms_paginate(
@@ -949,7 +943,7 @@ class RoomTimestampToEventRestServlet(RestServlet):
await assert_user_is_admin(self._auth, requester)
timestamp = parse_integer(request, "ts", required=True)
- direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
(
event_id,
diff --git a/synapse/rest/admin/statistics.py b/synapse/rest/admin/statistics.py
index 3b142b8402..9c45f4650d 100644
--- a/synapse/rest/admin/statistics.py
+++ b/synapse/rest/admin/statistics.py
@@ -16,8 +16,9 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
+from synapse.api.constants import Direction
from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.stats import UserSortOrder
@@ -102,13 +103,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- direction = parse_string(request, "dir", default="f")
- if direction not in ("f", "b"):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Unknown direction: %s" % (direction,),
- errcode=Codes.INVALID_PARAM,
- )
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 0841b89c1a..b9dca8ef3a 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -18,12 +18,13 @@ import secrets
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
-from synapse.api.constants import UserTypes
+from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_boolean,
+ parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -120,7 +121,7 @@ class UsersRestServletV2(RestServlet):
),
)
- direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
users, total = await self.store.get_users_paginate(
start,
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 9dd59196d9..7456d6f507 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -16,6 +16,7 @@ import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple
+from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -59,7 +60,7 @@ class RelationPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
- self._store, request, default_limit=5, default_dir="b"
+ self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
# The unstable version of this API returns an extra field for client
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 790614d721..d0db85cca7 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -26,7 +26,7 @@ from prometheus_client.core import Histogram
from twisted.web.server import Request
from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -44,6 +44,7 @@ from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_boolean,
+ parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -1297,7 +1298,7 @@ class TimestampLookupRestServlet(RestServlet):
await self._auth.check_user_in_room_or_world_readable(room_id, requester)
timestamp = parse_integer(request, "ts", required=True)
- direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+ direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
(
event_id,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0e47592be3..837dc7646e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -17,6 +17,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import (
DatabasePool,
@@ -167,7 +168,7 @@ class DataStore(
guests: bool = True,
deactivated: bool = False,
order_by: str = UserSortOrder.NAME.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
approved: bool = True,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
@@ -197,7 +198,7 @@ class DataStore(
# Set ordering
order_by_column = UserSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f42af34a2f..d7d08369ca 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -38,7 +38,7 @@ from typing_extensions import Literal
from twisted.internet import defer
-from synapse.api.constants import EventTypes
+from synapse.api.constants import Direction, EventTypes
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
@@ -2240,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_event_id_for_timestamp(
- self, room_id: str, timestamp: int, direction: str
+ self, room_id: str, timestamp: int, direction: Direction
) -> Optional[str]:
"""Find the closest event to the given timestamp in the given direction.
@@ -2248,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
- direction: ["f"|"b"] to indicate whether we should navigate forward
+ direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
The closest event_id otherwise None if we can't find any event in
the given direction.
"""
- if direction == "b":
+ if direction == Direction.BACKWARDS:
# Find closest event *before* a given timestamp. We use descending
# (which gives values largest to smallest) because we want the
# largest possible timestamp *before* the given timestamp.
@@ -2307,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore):
return None
- if direction not in ("f", "b"):
- raise ValueError("Unknown direction: %s" % (direction,))
-
return await self.db_pool.runInteraction(
"get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 9b172a64d8..b202c5eb87 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -26,6 +26,7 @@ from typing import (
cast,
)
+from synapse.api.constants import Direction
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
limit: int,
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
# Set ordering
order_by_column = MediaSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index fbbc018887..4ddb27f686 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -35,6 +35,7 @@ from typing import (
import attr
from synapse.api.constants import (
+ Direction,
EventContentFields,
EventTypes,
JoinRules,
@@ -2204,7 +2205,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
start: int,
limit: int,
- direction: str = "b",
+ direction: Direction = Direction.BACKWARDS,
user_id: Optional[str] = None,
room_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
@@ -2213,8 +2214,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
start: event offset to begin the query from
limit: number of rows to retrieve
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`)
+ direction: Whether to fetch the most recent first (backwards) or the
+ oldest first (forwards)
user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None
Returns:
@@ -2236,7 +2237,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
filters.append("er.room_id LIKE ?")
args.extend(["%" + room_id + "%"])
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 0c1cbd540d..d7b7d0c3c9 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -22,7 +22,7 @@ from typing_extensions import Counter
from twisted.internet.defer import DeferredLock
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import (
DatabasePool,
@@ -663,7 +663,7 @@ class StatsStore(StateDeltasStore):
from_ts: Optional[int] = None,
until_ts: Optional[int] = None,
order_by: Optional[str] = UserSortOrder.USER_ID.value,
- direction: Optional[str] = "f",
+ direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
@@ -714,7 +714,7 @@ class StatsStore(StateDeltasStore):
500, "Incorrect value for order_by provided: %s" % order_by
)
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index f8c6877ee8..6b33d809b6 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
+from synapse.api.constants import Direction
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
from synapse.storage.database import (
@@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: int,
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
- direction: str = "f",
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
@@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> Tuple[List[JsonDict], int]:
order_by_column = DestinationSortOrder(order_by).value
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
@@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
async def get_destination_rooms_paginate(
- self, destination: str, start: int, limit: int, direction: str = "f"
+ self,
+ destination: str,
+ start: int,
+ limit: int,
+ direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the
@@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
- if direction == "b":
+ if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 5cb7875181..a044280410 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -18,7 +18,7 @@ import attr
from synapse.api.constants import Direction
from synapse.api.errors import SynapseError
-from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.servlet import parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.storage.databases.main import DataStore
from synapse.types import StreamToken
@@ -44,15 +44,9 @@ class PaginationConfig:
store: "DataStore",
request: SynapseRequest,
default_limit: int,
- default_dir: str = "f",
+ default_dir: Direction = Direction.FORWARDS,
) -> "PaginationConfig":
- direction_str = parse_string(
- request,
- "dir",
- default=default_dir,
- allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
- )
- direction = Direction(direction_str)
+ direction = parse_enum(request, "dir", Direction, default=default_dir)
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
|