summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-01-27 07:27:55 -0500
committerGitHub <noreply@github.com>2023-01-27 07:27:55 -0500
commit265735db9d7b0698a511fc9389db4d6f104f1aa8 (patch)
treeca495a3241d62ffa8da7cb29df90c467497ff8ed /synapse/storage/databases/main/stream.py
parentAdd missing type hints in tests (#14879) (diff)
downloadsynapse-265735db9d7b0698a511fc9389db4d6f104f1aa8.tar.xz
Use an enum for direction. (#14927)
For better type safety we  use an enum instead of strings to
configure direction (backwards or forwards).
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py59
1 files changed, 31 insertions, 28 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 8977bf33e7..818c46182e 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -55,6 +55,7 @@ from typing_extensions import Literal
 
 from twisted.internet import defer
 
+from synapse.api.constants import Direction
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
 _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
-
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
@@ -104,7 +104,7 @@ class _EventsAround:
 
 
 def generate_pagination_where_clause(
-    direction: str,
+    direction: Direction,
     column_names: Tuple[str, str],
     from_token: Optional[Tuple[Optional[int], int]],
     to_token: Optional[Tuple[Optional[int], int]],
@@ -130,27 +130,26 @@ def generate_pagination_where_clause(
           token, but include those that match the to token.
 
     Args:
-        direction: Whether we're paginating backwards("b") or forwards ("f").
+        direction: Whether we're paginating backwards or forwards.
         column_names: The column names to bound. Must *not* be user defined as
             these get inserted directly into the SQL statement without escapes.
         from_token: The start point for the pagination. This is an exclusive
-            minimum bound if direction is "f", and an inclusive maximum bound if
-            direction is "b".
+            minimum bound if direction is forwards, and an inclusive maximum bound if
+            direction is backwards.
         to_token: The endpoint point for the pagination. This is an inclusive
-            maximum bound if direction is "f", and an exclusive minimum bound if
-            direction is "b".
+            maximum bound if direction is forwards, and an exclusive minimum bound if
+            direction is backwards.
         engine: The database engine to generate the clauses for
 
     Returns:
         The sql expression
     """
-    assert direction in ("b", "f")
 
     where_clause = []
     if from_token:
         where_clause.append(
             _make_generic_sql_bound(
-                bound=">=" if direction == "b" else "<",
+                bound=">=" if direction == Direction.BACKWARDS else "<",
                 column_names=column_names,
                 values=from_token,
                 engine=engine,
@@ -160,7 +159,7 @@ def generate_pagination_where_clause(
     if to_token:
         where_clause.append(
             _make_generic_sql_bound(
-                bound="<" if direction == "b" else ">=",
+                bound="<" if direction == Direction.BACKWARDS else ">=",
                 column_names=column_names,
                 values=to_token,
                 engine=engine,
@@ -171,7 +170,7 @@ def generate_pagination_where_clause(
 
 
 def generate_pagination_bounds(
-    direction: str,
+    direction: Direction,
     from_token: Optional[RoomStreamToken],
     to_token: Optional[RoomStreamToken],
 ) -> Tuple[
@@ -181,7 +180,7 @@ def generate_pagination_bounds(
     Generate a start and end point for this page of events.
 
     Args:
-        direction: Whether pagination is going forwards or backwards. One of "f" or "b".
+        direction: Whether pagination is going forwards or backwards.
         from_token: The token to start pagination at, or None to start at the first value.
         to_token: The token to end pagination at, or None to not limit the end point.
 
@@ -201,7 +200,7 @@ def generate_pagination_bounds(
     # Tokens really represent positions between elements, but we use
     # the convention of pointing to the event before the gap. Hence
     # we have a bit of asymmetry when it comes to equalities.
-    if direction == "b":
+    if direction == Direction.BACKWARDS:
         order = "DESC"
     else:
         order = "ASC"
@@ -215,7 +214,7 @@ def generate_pagination_bounds(
     if from_token:
         if from_token.topological is not None:
             from_bound = from_token.as_historical_tuple()
-        elif direction == "b":
+        elif direction == Direction.BACKWARDS:
             from_bound = (
                 None,
                 from_token.get_max_stream_pos(),
@@ -230,7 +229,7 @@ def generate_pagination_bounds(
     if to_token:
         if to_token.topological is not None:
             to_bound = to_token.as_historical_tuple()
-        elif direction == "b":
+        elif direction == Direction.BACKWARDS:
             to_bound = (
                 None,
                 to_token.stream,
@@ -245,20 +244,20 @@ def generate_pagination_bounds(
 
 
 def generate_next_token(
-    direction: str, last_topo_ordering: int, last_stream_ordering: int
+    direction: Direction, last_topo_ordering: int, last_stream_ordering: int
 ) -> RoomStreamToken:
     """
     Generate the next room stream token based on the currently returned data.
 
     Args:
-        direction: Whether pagination is going forwards or backwards. One of "f" or "b".
+        direction: Whether pagination is going forwards or backwards.
         last_topo_ordering: The last topological ordering being returned.
         last_stream_ordering: The last stream ordering being returned.
 
     Returns:
         A new RoomStreamToken to return to the client.
     """
-    if direction == "b":
+    if direction == Direction.BACKWARDS:
         # Tokens are positions between events.
         # This token points *after* the last event in the chunk.
         # We need it to point to the event before it in the chunk
@@ -1201,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn,
             room_id,
             before_token,
-            direction="b",
+            direction=Direction.BACKWARDS,
             limit=before_limit,
             event_filter=event_filter,
         )
@@ -1211,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn,
             room_id,
             after_token,
-            direction="f",
+            direction=Direction.FORWARDS,
             limit=after_limit,
             event_filter=event_filter,
         )
@@ -1374,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         limit: int = -1,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@@ -1385,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             room_id
             from_token: The token used to stream from
             to_token: A token which if given limits the results to only those before
-            direction: Either 'b' or 'f' to indicate whether we are paginating
-                forwards or backwards from `from_key`.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: The maximum number of events to return.
             event_filter: If provided filters the events to
                 those that match the filter.
@@ -1489,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             _EventDictReturn(event_id, topological_ordering, stream_ordering)
             for event_id, instance_name, topological_ordering, stream_ordering in txn
             if _filter_results(
-                lower_token=to_token if direction == "b" else from_token,
-                upper_token=from_token if direction == "b" else to_token,
+                lower_token=to_token
+                if direction == Direction.BACKWARDS
+                else from_token,
+                upper_token=from_token
+                if direction == Direction.BACKWARDS
+                else to_token,
                 instance_name=instance_name,
                 topological_ordering=topological_ordering,
                 stream_ordering=stream_ordering,
@@ -1514,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         room_id: str,
         from_key: RoomStreamToken,
         to_key: Optional[RoomStreamToken] = None,
-        direction: str = "b",
+        direction: Direction = Direction.BACKWARDS,
         limit: int = -1,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[EventBase], RoomStreamToken]:
@@ -1524,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             room_id
             from_key: The token used to stream from
             to_key: A token which if given limits the results to only those before
-            direction: Either 'b' or 'f' to indicate whether we are paginating
-                forwards or backwards from `from_key`.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: The maximum number of events to return.
             event_filter: If provided filters the events to those that match the filter.