summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-01-27 07:27:55 -0500
committerGitHub <noreply@github.com>2023-01-27 07:27:55 -0500
commit265735db9d7b0698a511fc9389db4d6f104f1aa8 (patch)
treeca495a3241d62ffa8da7cb29df90c467497ff8ed /synapse
parentAdd missing type hints in tests (#14879) (diff)
downloadsynapse-265735db9d7b0698a511fc9389db4d6f104f1aa8.tar.xz
Use an enum for direction. (#14927)
For better type safety we  use an enum instead of strings to
configure direction (backwards or forwards).
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/constants.py7
-rw-r--r--synapse/handlers/admin.py4
-rw-r--r--synapse/handlers/initial_sync.py16
-rw-r--r--synapse/handlers/pagination.py6
-rw-r--r--synapse/handlers/relations.py8
-rw-r--r--synapse/storage/databases/main/relations.py8
-rw-r--r--synapse/storage/databases/main/stream.py59
-rw-r--r--synapse/streams/config.py11
8 files changed, 75 insertions, 44 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6432d32d83..6f9239d21c 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -17,6 +17,8 @@
 
 """Contains constants from the specification."""
 
+import enum
+
 from typing_extensions import Final
 
 # the max size of a (canonical-json-encoded) event
@@ -290,3 +292,8 @@ class ApprovalNoticeMedium:
 
     NONE = "org.matrix.msc3866.none"
     EMAIL = "org.matrix.msc3866.email"
+
+
+class Direction(enum.Enum):
+    BACKWARDS = "b"
+    FORWARDS = "f"
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5bf8e86387..c81ea34758 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -16,7 +16,7 @@ import abc
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
-from synapse.api.constants import Membership
+from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
 from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
 from synapse.visibility import filter_events_for_client
@@ -197,7 +197,7 @@ class AdminHandler:
             # efficient method perhaps but it does guarantee we get everything.
             while True:
                 events, _ = await self.store.paginate_room_events(
-                    room_id, from_key, to_key, limit=100, direction="f"
+                    room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
                 )
                 if not events:
                     break
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 8c2260ad7d..191529bd8e 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -15,7 +15,13 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
-from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
+from synapse.api.constants import (
+    AccountDataTypes,
+    Direction,
+    EduTypes,
+    EventTypes,
+    Membership,
+)
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig
@@ -57,7 +63,13 @@ class InitialSyncHandler:
         self.validator = EventValidator()
         self.snapshot_cache: ResponseCache[
             Tuple[
-                str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
+                str,
+                Optional[StreamToken],
+                Optional[StreamToken],
+                Direction,
+                int,
+                bool,
+                bool,
             ]
         ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1fe6567185..ceefa16b49 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -19,7 +19,7 @@ import attr
 
 from twisted.python.failure import Failure
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import Direction, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events.utils import SerializeEventConfig
@@ -448,7 +448,7 @@ class PaginationHandler:
 
         if pagin_config.from_token:
             from_token = pagin_config.from_token
-        elif pagin_config.direction == "f":
+        elif pagin_config.direction == Direction.FORWARDS:
             from_token = (
                 await self.hs.get_event_sources().get_start_token_for_pagination(
                     room_id
@@ -476,7 +476,7 @@ class PaginationHandler:
                     room_id, requester, allow_departed_users=True
                 )
 
-            if pagin_config.direction == "b":
+            if pagin_config.direction == Direction.BACKWARDS:
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index e96f9999a8..0fb15391e0 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O
 
 import attr
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import Direction, EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -413,7 +413,11 @@ class RelationsHandler:
 
                 # Attempt to find another event to use as the latest event.
                 potential_events, _ = await self._main_store.get_relations_for_event(
-                    event_id, event, room_id, RelationTypes.THREAD, direction="f"
+                    event_id,
+                    event,
+                    room_id,
+                    RelationTypes.THREAD,
+                    direction=Direction.FORWARDS,
                 )
 
                 # Filter out ignored users.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index be2242b6ac..0018d6f7ab 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -30,7 +30,7 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
@@ -168,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
         limit: int = 5,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
     ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
@@ -181,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
             limit: Only fetch the most recent `limit` events.
-            direction: Whether to fetch the most recent first (`"b"`) or the
-                oldest first (`"f"`).
+            direction: Whether to fetch the most recent first (backwards) or the
+                oldest first (forwards).
             from_token: Fetch rows from the given token, or from the start if None.
             to_token: Fetch rows up to the given token, or up to the end if None.
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8977bf33e7..818c46182e 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -55,6 +55,7 @@ from typing_extensions import Literal
 
 from twisted.internet import defer
 
+from synapse.api.constants import Direction
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
 _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
-
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
@@ -104,7 +104,7 @@ class _EventsAround:
 
 
 def generate_pagination_where_clause(
-    direction: str,
+    direction: Direction,
     column_names: Tuple[str, str],
     from_token: Optional[Tuple[Optional[int], int]],
     to_token: Optional[Tuple[Optional[int], int]],
@@ -130,27 +130,26 @@ def generate_pagination_where_clause(
           token, but include those that match the to token.
 
     Args:
-        direction: Whether we're paginating backwards("b") or forwards ("f").
+        direction: Whether we're paginating backwards or forwards.
         column_names: The column names to bound. Must *not* be user defined as
             these get inserted directly into the SQL statement without escapes.
         from_token: The start point for the pagination. This is an exclusive
-            minimum bound if direction is "f", and an inclusive maximum bound if
-            direction is "b".
+            minimum bound if direction is forwards, and an inclusive maximum bound if
+            direction is backwards.
         to_token: The endpoint point for the pagination. This is an inclusive
-            maximum bound if direction is "f", and an exclusive minimum bound if
-            direction is "b".
+            maximum bound if direction is forwards, and an exclusive minimum bound if
+            direction is backwards.
         engine: The database engine to generate the clauses for
 
     Returns:
         The sql expression
     """
-    assert direction in ("b", "f")
 
     where_clause = []
     if from_token:
         where_clause.append(
             _make_generic_sql_bound(
-                bound=">=" if direction == "b" else "<",
+                bound=">=" if direction == Direction.BACKWARDS else "<",
                 column_names=column_names,
                 values=from_token,
                 engine=engine,
@@ -160,7 +159,7 @@ def generate_pagination_where_clause(
     if to_token:
         where_clause.append(
             _make_generic_sql_bound(
-                bound="<" if direction == "b" else ">=",
+                bound="<" if direction == Direction.BACKWARDS else ">=",
                 column_names=column_names,
                 values=to_token,
                 engine=engine,
@@ -171,7 +170,7 @@ def generate_pagination_where_clause(
 
 
 def generate_pagination_bounds(
-    direction: str,
+    direction: Direction,
     from_token: Optional[RoomStreamToken],
     to_token: Optional[RoomStreamToken],
 ) -> Tuple[
@@ -181,7 +180,7 @@ def generate_pagination_bounds(
     Generate a start and end point for this page of events.
 
     Args:
-        direction: Whether pagination is going forwards or backwards. One of "f" or "b".
+        direction: Whether pagination is going forwards or backwards.
         from_token: The token to start pagination at, or None to start at the first value.
         to_token: The token to end pagination at, or None to not limit the end point.
 
@@ -201,7 +200,7 @@ def generate_pagination_bounds(
     # Tokens really represent positions between elements, but we use
     # the convention of pointing to the event before the gap. Hence
     # we have a bit of asymmetry when it comes to equalities.
-    if direction == "b":
+    if direction == Direction.BACKWARDS:
         order = "DESC"
     else:
         order = "ASC"
@@ -215,7 +214,7 @@ def generate_pagination_bounds(
     if from_token:
         if from_token.topological is not None:
             from_bound = from_token.as_historical_tuple()
-        elif direction == "b":
+        elif direction == Direction.BACKWARDS:
             from_bound = (
                 None,
                 from_token.get_max_stream_pos(),
@@ -230,7 +229,7 @@ def generate_pagination_bounds(
     if to_token:
         if to_token.topological is not None:
             to_bound = to_token.as_historical_tuple()
-        elif direction == "b":
+        elif direction == Direction.BACKWARDS:
             to_bound = (
                 None,
                 to_token.stream,
@@ -245,20 +244,20 @@ def generate_pagination_bounds(
 
 
 def generate_next_token(
-    direction: str, last_topo_ordering: int, last_stream_ordering: int
+    direction: Direction, last_topo_ordering: int, last_stream_ordering: int
 ) -> RoomStreamToken:
     """
     Generate the next room stream token based on the currently returned data.
 
     Args:
-        direction: Whether pagination is going forwards or backwards. One of "f" or "b".
+        direction: Whether pagination is going forwards or backwards.
         last_topo_ordering: The last topological ordering being returned.
         last_stream_ordering: The last stream ordering being returned.
 
     Returns:
         A new RoomStreamToken to return to the client.
     """
-    if direction == "b":
+    if direction == Direction.BACKWARDS:
         # Tokens are positions between events.
         # This token points *after* the last event in the chunk.
         # We need it to point to the event before it in the chunk
@@ -1201,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn,
             room_id,
             before_token,
-            direction="b",
+            direction=Direction.BACKWARDS,
             limit=before_limit,
             event_filter=event_filter,
         )
@@ -1211,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn,
             room_id,
             after_token,
-            direction="f",
+            direction=Direction.FORWARDS,
             limit=after_limit,
             event_filter=event_filter,
         )
@@ -1374,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         limit: int = -1,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@@ -1385,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             room_id
             from_token: The token used to stream from
             to_token: A token which if given limits the results to only those before
-            direction: Either 'b' or 'f' to indicate whether we are paginating
-                forwards or backwards from `from_key`.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: The maximum number of events to return.
             event_filter: If provided filters the events to
                 those that match the filter.
@@ -1489,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             _EventDictReturn(event_id, topological_ordering, stream_ordering)
             for event_id, instance_name, topological_ordering, stream_ordering in txn
             if _filter_results(
-                lower_token=to_token if direction == "b" else from_token,
-                upper_token=from_token if direction == "b" else to_token,
+                lower_token=to_token
+                if direction == Direction.BACKWARDS
+                else from_token,
+                upper_token=from_token
+                if direction == Direction.BACKWARDS
+                else to_token,
                 instance_name=instance_name,
                 topological_ordering=topological_ordering,
                 stream_ordering=stream_ordering,
@@ -1514,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         room_id: str,
         from_key: RoomStreamToken,
         to_key: Optional[RoomStreamToken] = None,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         limit: int = -1,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[EventBase], RoomStreamToken]:
@@ -1524,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             room_id
             from_key: The token used to stream from
             to_key: A token which if given limits the results to only those before
-            direction: Either 'b' or 'f' to indicate whether we are paginating
-                forwards or backwards from `from_key`.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: The maximum number of events to return.
             event_filter: If provided filters the events to those that match the filter.
 
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 6df2de919c..5cb7875181 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -16,6 +16,7 @@ from typing import Optional
 
 import attr
 
+from synapse.api.constants import Direction
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -34,7 +35,7 @@ class PaginationConfig:
 
     from_token: Optional[StreamToken]
     to_token: Optional[StreamToken]
-    direction: str
+    direction: Direction
     limit: int
 
     @classmethod
@@ -45,9 +46,13 @@ class PaginationConfig:
         default_limit: int,
         default_dir: str = "f",
     ) -> "PaginationConfig":
-        direction = parse_string(
-            request, "dir", default=default_dir, allowed_values=["f", "b"]
+        direction_str = parse_string(
+            request,
+            "dir",
+            default=default_dir,
+            allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
         )
+        direction = Direction(direction_str)
 
         from_tok_str = parse_string(request, "from")
         to_tok_str = parse_string(request, "to")