From 265735db9d7b0698a511fc9389db4d6f104f1aa8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 Jan 2023 07:27:55 -0500 Subject: Use an enum for direction. (#14927) For better type safety we use an enum instead of strings to configure direction (backwards or forwards). --- synapse/storage/databases/main/relations.py | 8 ++-- synapse/storage/databases/main/stream.py | 59 +++++++++++++++-------------- 2 files changed, 35 insertions(+), 32 deletions(-) (limited to 'synapse/storage/databases') diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index be2242b6ac..0018d6f7ab 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -30,7 +30,7 @@ from typing import ( import attr -from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore @@ -168,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Optional[str] = None, event_type: Optional[str] = None, limit: int = 5, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: @@ -181,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore): relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. limit: Only fetch the most recent `limit` events. - direction: Whether to fetch the most recent first (`"b"`) or the - oldest first (`"f"`). + direction: Whether to fetch the most recent first (backwards) or the + oldest first (forwards). from_token: Fetch rows from the given token, or from the start if None. to_token: Fetch rows up to the given token, or up to the end if None. diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 8977bf33e7..818c46182e 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -55,6 +55,7 @@ from typing_extensions import Literal from twisted.internet import defer +from synapse.api.constants import Direction from synapse.api.filtering import Filter from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000 _STREAM_TOKEN = "stream" _TOPOLOGICAL_TOKEN = "topological" - # Used as return values for pagination APIs @attr.s(slots=True, frozen=True, auto_attribs=True) class _EventDictReturn: @@ -104,7 +104,7 @@ class _EventsAround: def generate_pagination_where_clause( - direction: str, + direction: Direction, column_names: Tuple[str, str], from_token: Optional[Tuple[Optional[int], int]], to_token: Optional[Tuple[Optional[int], int]], @@ -130,27 +130,26 @@ def generate_pagination_where_clause( token, but include those that match the to token. Args: - direction: Whether we're paginating backwards("b") or forwards ("f"). + direction: Whether we're paginating backwards or forwards. column_names: The column names to bound. Must *not* be user defined as these get inserted directly into the SQL statement without escapes. from_token: The start point for the pagination. This is an exclusive - minimum bound if direction is "f", and an inclusive maximum bound if - direction is "b". + minimum bound if direction is forwards, and an inclusive maximum bound if + direction is backwards. to_token: The endpoint point for the pagination. This is an inclusive - maximum bound if direction is "f", and an exclusive minimum bound if - direction is "b". + maximum bound if direction is forwards, and an exclusive minimum bound if + direction is backwards. engine: The database engine to generate the clauses for Returns: The sql expression """ - assert direction in ("b", "f") where_clause = [] if from_token: where_clause.append( _make_generic_sql_bound( - bound=">=" if direction == "b" else "<", + bound=">=" if direction == Direction.BACKWARDS else "<", column_names=column_names, values=from_token, engine=engine, @@ -160,7 +159,7 @@ def generate_pagination_where_clause( if to_token: where_clause.append( _make_generic_sql_bound( - bound="<" if direction == "b" else ">=", + bound="<" if direction == Direction.BACKWARDS else ">=", column_names=column_names, values=to_token, engine=engine, @@ -171,7 +170,7 @@ def generate_pagination_where_clause( def generate_pagination_bounds( - direction: str, + direction: Direction, from_token: Optional[RoomStreamToken], to_token: Optional[RoomStreamToken], ) -> Tuple[ @@ -181,7 +180,7 @@ def generate_pagination_bounds( Generate a start and end point for this page of events. Args: - direction: Whether pagination is going forwards or backwards. One of "f" or "b". + direction: Whether pagination is going forwards or backwards. from_token: The token to start pagination at, or None to start at the first value. to_token: The token to end pagination at, or None to not limit the end point. @@ -201,7 +200,7 @@ def generate_pagination_bounds( # Tokens really represent positions between elements, but we use # the convention of pointing to the event before the gap. Hence # we have a bit of asymmetry when it comes to equalities. - if direction == "b": + if direction == Direction.BACKWARDS: order = "DESC" else: order = "ASC" @@ -215,7 +214,7 @@ def generate_pagination_bounds( if from_token: if from_token.topological is not None: from_bound = from_token.as_historical_tuple() - elif direction == "b": + elif direction == Direction.BACKWARDS: from_bound = ( None, from_token.get_max_stream_pos(), @@ -230,7 +229,7 @@ def generate_pagination_bounds( if to_token: if to_token.topological is not None: to_bound = to_token.as_historical_tuple() - elif direction == "b": + elif direction == Direction.BACKWARDS: to_bound = ( None, to_token.stream, @@ -245,20 +244,20 @@ def generate_pagination_bounds( def generate_next_token( - direction: str, last_topo_ordering: int, last_stream_ordering: int + direction: Direction, last_topo_ordering: int, last_stream_ordering: int ) -> RoomStreamToken: """ Generate the next room stream token based on the currently returned data. Args: - direction: Whether pagination is going forwards or backwards. One of "f" or "b". + direction: Whether pagination is going forwards or backwards. last_topo_ordering: The last topological ordering being returned. last_stream_ordering: The last stream ordering being returned. Returns: A new RoomStreamToken to return to the client. """ - if direction == "b": + if direction == Direction.BACKWARDS: # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk @@ -1201,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, before_token, - direction="b", + direction=Direction.BACKWARDS, limit=before_limit, event_filter=event_filter, ) @@ -1211,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): txn, room_id, after_token, - direction="f", + direction=Direction.FORWARDS, limit=after_limit, event_filter=event_filter, ) @@ -1374,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_token: RoomStreamToken, to_token: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[_EventDictReturn], RoomStreamToken]: @@ -1385,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_token: The token used to stream from to_token: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. @@ -1489,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): _EventDictReturn(event_id, topological_ordering, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn if _filter_results( - lower_token=to_token if direction == "b" else from_token, - upper_token=from_token if direction == "b" else to_token, + lower_token=to_token + if direction == Direction.BACKWARDS + else from_token, + upper_token=from_token + if direction == Direction.BACKWARDS + else to_token, instance_name=instance_name, topological_ordering=topological_ordering, stream_ordering=stream_ordering, @@ -1514,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id: str, from_key: RoomStreamToken, to_key: Optional[RoomStreamToken] = None, - direction: str = "b", + direction: Direction = Direction.BACKWARDS, limit: int = -1, event_filter: Optional[Filter] = None, ) -> Tuple[List[EventBase], RoomStreamToken]: @@ -1524,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): room_id from_key: The token used to stream from to_key: A token which if given limits the results to only those before - direction: Either 'b' or 'f' to indicate whether we are paginating - forwards or backwards from `from_key`. + direction: Indicates whether we are paginating forwards or backwards + from `from_key`. limit: The maximum number of events to return. event_filter: If provided filters the events to those that match the filter. -- cgit 1.4.1