From 63c0e9e1954fc7fc10a2575c54aecc8944de60f3 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 8 Sep 2020 16:48:15 +0100 Subject: Add types to StreamToken and RoomStreamToken (#8279) The intention here is to change `StreamToken.room_key` to be a `RoomStreamToken` in a future PR, but that is a big enough change without this refactoring too. --- changelog.d/8279.misc | 1 + synapse/handlers/sync.py | 5 +- synapse/storage/databases/main/devices.py | 7 +- synapse/storage/databases/main/stream.py | 21 +++-- synapse/types.py | 152 +++++++++++++++--------------- 5 files changed, 95 insertions(+), 91 deletions(-) create mode 100644 changelog.d/8279.misc diff --git a/changelog.d/8279.misc b/changelog.d/8279.misc new file mode 100644 index 0000000000..99f669001f --- /dev/null +++ b/changelog.d/8279.misc @@ -0,0 +1 @@ +Add type hints to `StreamToken` and `RoomStreamToken` classes. diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e2ddb628ff..cc47e8b62c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1310,12 +1310,11 @@ class SyncHandler: presence_source = self.event_sources.sources["presence"] since_token = sync_result_builder.since_token + presence_key = None + include_offline = False if since_token and not sync_result_builder.full_state: presence_key = since_token.presence_key include_offline = True - else: - presence_key = None - include_offline = False presence, presence_key = await presence_source.get_new_events( user=user, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index add4e3ea0e..306fc6947c 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore): } async def get_users_whose_devices_changed( - self, from_key: str, user_ids: Iterable[str] + self, from_key: int, user_ids: Iterable[str] ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that are in the given list of user_ids. @@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore): Returns: The set of user_ids whose devices have changed since `from_key` """ - from_key = int(from_key) # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. @@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore): ) async def get_users_whose_signatures_changed( - self, user_id: str, from_key: str + self, user_id: str, from_key: int ) -> Set[str]: """Get the users who have new cross-signing signatures made by `user_id` since `from_key`. @@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore): Returns: A set of user IDs with updated signatures. """ - from_key = int(from_key) + if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): sql = """ SELECT DISTINCT user_ids FROM user_signature_stream diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index be6df8a6d1..08a13a8b47 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -79,8 +79,8 @@ _EventDictReturn = namedtuple( def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], - from_token: Optional[Tuple[int, int]], - to_token: Optional[Tuple[int, int]], + from_token: Optional[Tuple[Optional[int], int]], + to_token: Optional[Tuple[Optional[int], int]], engine: BaseDatabaseEngine, ) -> str: """Creates an SQL expression to bound the columns by the pagination @@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if limit == 0: return [], end_token - end_token = RoomStreamToken.parse(end_token) + parsed_end_token = RoomStreamToken.parse(end_token) rows, token = await self.db_pool.runInteraction( "get_recent_event_ids_for_room", self._paginate_room_events_txn, room_id, - from_token=end_token, + from_token=parsed_end_token, limit=limit, ) @@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): bounds = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=from_token, - to_token=to_token, + from_token=from_token.as_tuple(), + to_token=to_token.as_tuple() if to_token else None, engine=self.database_engine, ) @@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): and `to_key`). """ - from_key = RoomStreamToken.parse(from_key) + parsed_from_key = RoomStreamToken.parse(from_key) + parsed_to_key = None if to_key: - to_key = RoomStreamToken.parse(to_key) + parsed_to_key = RoomStreamToken.parse(to_key) rows, token = await self.db_pool.runInteraction( "paginate_room_events", self._paginate_room_events_txn, room_id, - from_key, - to_key, + parsed_from_key, + parsed_to_key, direction, limit, event_filter, diff --git a/synapse/types.py b/synapse/types.py index f7de48f148..ba45335038 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -18,7 +18,7 @@ import re import string import sys from collections import namedtuple -from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar +from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -362,22 +362,79 @@ def map_username_to_mxid_localpart(username, case_sensitive=False): return username.decode("ascii") -class StreamToken( - namedtuple( - "Token", - ( - "room_key", - "presence_key", - "typing_key", - "receipt_key", - "account_data_key", - "push_rules_key", - "to_device_key", - "device_list_key", - "groups_key", - ), +@attr.s(frozen=True, slots=True) +class RoomStreamToken: + """Tokens are positions between events. The token "s1" comes after event 1. + + s0 s1 + | | + [0] V [1] V [2] + + Tokens can either be a point in the live event stream or a cursor going + through historic events. + + When traversing the live event stream events are ordered by when they + arrived at the homeserver. + + When traversing historic events the events are ordered by their depth in + the event graph "topological_ordering" and then by when they arrived at the + homeserver "stream_ordering". + + Live tokens start with an "s" followed by the "stream_ordering" id of the + event it comes after. Historic tokens start with a "t" followed by the + "topological_ordering" id of the event it comes after, followed by "-", + followed by the "stream_ordering" id of the event it comes after. + """ + + topological = attr.ib( + type=Optional[int], + validator=attr.validators.optional(attr.validators.instance_of(int)), ) -): + stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) + + @classmethod + def parse(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + if string[0] == "t": + parts = string[1:].split("-", 1) + return cls(topological=int(parts[0]), stream=int(parts[1])) + except Exception: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + @classmethod + def parse_stream_token(cls, string: str) -> "RoomStreamToken": + try: + if string[0] == "s": + return cls(topological=None, stream=int(string[1:])) + except Exception: + pass + raise SynapseError(400, "Invalid token %r" % (string,)) + + def as_tuple(self) -> Tuple[Optional[int], int]: + return (self.topological, self.stream) + + def __str__(self) -> str: + if self.topological is not None: + return "t%d-%d" % (self.topological, self.stream) + else: + return "s%d" % (self.stream,) + + +@attr.s(slots=True, frozen=True) +class StreamToken: + room_key = attr.ib(type=str) + presence_key = attr.ib(type=int) + typing_key = attr.ib(type=int) + receipt_key = attr.ib(type=int) + account_data_key = attr.ib(type=int) + push_rules_key = attr.ib(type=int) + to_device_key = attr.ib(type=int) + device_list_key = attr.ib(type=int) + groups_key = attr.ib(type=int) + _SEPARATOR = "_" START = None # type: StreamToken @@ -385,15 +442,15 @@ class StreamToken( def from_string(cls, string): try: keys = string.split(cls._SEPARATOR) - while len(keys) < len(cls._fields): + while len(keys) < len(attr.fields(cls)): # i.e. old token from before receipt_key keys.append("0") - return cls(*keys) + return cls(keys[0], *(int(k) for k in keys[1:])) except Exception: raise SynapseError(400, "Invalid Token") def to_string(self): - return self._SEPARATOR.join([str(k) for k in self]) + return self._SEPARATOR.join([str(k) for k in attr.astuple(self)]) @property def room_stream_id(self): @@ -435,63 +492,10 @@ class StreamToken( return self def copy_and_replace(self, key, new_value): - return self._replace(**{key: new_value}) - - -StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))) - - -class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): - """Tokens are positions between events. The token "s1" comes after event 1. - - s0 s1 - | | - [0] V [1] V [2] - - Tokens can either be a point in the live event stream or a cursor going - through historic events. - - When traversing the live event stream events are ordered by when they - arrived at the homeserver. - - When traversing historic events the events are ordered by their depth in - the event graph "topological_ordering" and then by when they arrived at the - homeserver "stream_ordering". - - Live tokens start with an "s" followed by the "stream_ordering" id of the - event it comes after. Historic tokens start with a "t" followed by the - "topological_ordering" id of the event it comes after, followed by "-", - followed by the "stream_ordering" id of the event it comes after. - """ + return attr.evolve(self, **{key: new_value}) - __slots__ = [] # type: list - - @classmethod - def parse(cls, string): - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - if string[0] == "t": - parts = string[1:].split("-", 1) - return cls(topological=int(parts[0]), stream=int(parts[1])) - except Exception: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - @classmethod - def parse_stream_token(cls, string): - try: - if string[0] == "s": - return cls(topological=None, stream=int(string[1:])) - except Exception: - pass - raise SynapseError(400, "Invalid token %r" % (string,)) - - def __str__(self): - if self.topological is not None: - return "t%d-%d" % (self.topological, self.stream) - else: - return "s%d" % (self.stream,) +StreamToken.START = StreamToken.from_string("s0_0") class ThirdPartyInstanceID( -- cgit 1.4.1