diff --git a/changelog.d/8279.misc b/changelog.d/8279.misc
new file mode 100644
index 0000000000..99f669001f
--- /dev/null
+++ b/changelog.d/8279.misc
@@ -0,0 +1 @@
+Add type hints to `StreamToken` and `RoomStreamToken` classes.
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e2ddb628ff..cc47e8b62c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1310,12 +1310,11 @@ class SyncHandler:
presence_source = self.event_sources.sources["presence"]
since_token = sync_result_builder.since_token
+ presence_key = None
+ include_offline = False
if since_token and not sync_result_builder.full_state:
presence_key = since_token.presence_key
include_offline = True
- else:
- presence_key = None
- include_offline = False
presence, presence_key = await presence_source.get_new_events(
user=user,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..306fc6947c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
}
async def get_users_whose_devices_changed(
- self, from_key: str, user_ids: Iterable[str]
+ self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
The set of user_ids whose devices have changed since `from_key`
"""
- from_key = int(from_key)
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def get_users_whose_signatures_changed(
- self, user_id: str, from_key: str
+ self, user_id: str, from_key: int
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
A set of user IDs with updated signatures.
"""
- from_key = int(from_key)
+
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """
SELECT DISTINCT user_ids FROM user_signature_stream
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index be6df8a6d1..08a13a8b47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
- from_token: Optional[Tuple[int, int]],
- to_token: Optional[Tuple[int, int]],
+ from_token: Optional[Tuple[Optional[int], int]],
+ to_token: Optional[Tuple[Optional[int], int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination
@@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
- end_token = RoomStreamToken.parse(end_token)
+ parsed_end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
- from_token=end_token,
+ from_token=parsed_end_token,
limit=limit,
)
@@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token,
- to_token=to_token,
+ from_token=from_token.as_tuple(),
+ to_token=to_token.as_tuple() if to_token else None,
engine=self.database_engine,
)
@@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
- from_key = RoomStreamToken.parse(from_key)
+ parsed_from_key = RoomStreamToken.parse(from_key)
+ parsed_to_key = None
if to_key:
- to_key = RoomStreamToken.parse(to_key)
+ parsed_to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
- from_key,
- to_key,
+ parsed_from_key,
+ parsed_to_key,
direction,
limit,
event_filter,
diff --git a/synapse/types.py b/synapse/types.py
index f7de48f148..ba45335038 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,7 @@ import re
import string
import sys
from collections import namedtuple
-from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -362,22 +362,79 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
-class StreamToken(
- namedtuple(
- "Token",
- (
- "room_key",
- "presence_key",
- "typing_key",
- "receipt_key",
- "account_data_key",
- "push_rules_key",
- "to_device_key",
- "device_list_key",
- "groups_key",
- ),
+@attr.s(frozen=True, slots=True)
+class RoomStreamToken:
+ """Tokens are positions between events. The token "s1" comes after event 1.
+
+ s0 s1
+ | |
+ [0] V [1] V [2]
+
+ Tokens can either be a point in the live event stream or a cursor going
+ through historic events.
+
+ When traversing the live event stream events are ordered by when they
+ arrived at the homeserver.
+
+ When traversing historic events the events are ordered by their depth in
+ the event graph "topological_ordering" and then by when they arrived at the
+ homeserver "stream_ordering".
+
+ Live tokens start with an "s" followed by the "stream_ordering" id of the
+ event it comes after. Historic tokens start with a "t" followed by the
+ "topological_ordering" id of the event it comes after, followed by "-",
+ followed by the "stream_ordering" id of the event it comes after.
+ """
+
+ topological = attr.ib(
+ type=Optional[int],
+ validator=attr.validators.optional(attr.validators.instance_of(int)),
)
-):
+ stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+
+ @classmethod
+ def parse(cls, string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ if string[0] == "t":
+ parts = string[1:].split("-", 1)
+ return cls(topological=int(parts[0]), stream=int(parts[1]))
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid token %r" % (string,))
+
+ @classmethod
+ def parse_stream_token(cls, string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid token %r" % (string,))
+
+ def as_tuple(self) -> Tuple[Optional[int], int]:
+ return (self.topological, self.stream)
+
+ def __str__(self) -> str:
+ if self.topological is not None:
+ return "t%d-%d" % (self.topological, self.stream)
+ else:
+ return "s%d" % (self.stream,)
+
+
+@attr.s(slots=True, frozen=True)
+class StreamToken:
+ room_key = attr.ib(type=str)
+ presence_key = attr.ib(type=int)
+ typing_key = attr.ib(type=int)
+ receipt_key = attr.ib(type=int)
+ account_data_key = attr.ib(type=int)
+ push_rules_key = attr.ib(type=int)
+ to_device_key = attr.ib(type=int)
+ device_list_key = attr.ib(type=int)
+ groups_key = attr.ib(type=int)
+
_SEPARATOR = "_"
START = None # type: StreamToken
@@ -385,15 +442,15 @@ class StreamToken(
def from_string(cls, string):
try:
keys = string.split(cls._SEPARATOR)
- while len(keys) < len(cls._fields):
+ while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
- return cls(*keys)
+ return cls(keys[0], *(int(k) for k in keys[1:]))
except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
- return self._SEPARATOR.join([str(k) for k in self])
+ return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
@property
def room_stream_id(self):
@@ -435,63 +492,10 @@ class StreamToken(
return self
def copy_and_replace(self, key, new_value):
- return self._replace(**{key: new_value})
-
-
-StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
-
-
-class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
- """Tokens are positions between events. The token "s1" comes after event 1.
-
- s0 s1
- | |
- [0] V [1] V [2]
-
- Tokens can either be a point in the live event stream or a cursor going
- through historic events.
-
- When traversing the live event stream events are ordered by when they
- arrived at the homeserver.
-
- When traversing historic events the events are ordered by their depth in
- the event graph "topological_ordering" and then by when they arrived at the
- homeserver "stream_ordering".
-
- Live tokens start with an "s" followed by the "stream_ordering" id of the
- event it comes after. Historic tokens start with a "t" followed by the
- "topological_ordering" id of the event it comes after, followed by "-",
- followed by the "stream_ordering" id of the event it comes after.
- """
+ return attr.evolve(self, **{key: new_value})
- __slots__ = [] # type: list
-
- @classmethod
- def parse(cls, string):
- try:
- if string[0] == "s":
- return cls(topological=None, stream=int(string[1:]))
- if string[0] == "t":
- parts = string[1:].split("-", 1)
- return cls(topological=int(parts[0]), stream=int(parts[1]))
- except Exception:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
- @classmethod
- def parse_stream_token(cls, string):
- try:
- if string[0] == "s":
- return cls(topological=None, stream=int(string[1:]))
- except Exception:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
-
- def __str__(self):
- if self.topological is not None:
- return "t%d-%d" % (self.topological, self.stream)
- else:
- return "s%d" % (self.stream,)
+StreamToken.START = StreamToken.from_string("s0_0")
class ThirdPartyInstanceID(
|