diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index b22a13ef01..c0d30ac2a3 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -20,6 +20,7 @@
#
#
import abc
+import logging
import re
import string
from enum import Enum
@@ -74,6 +75,9 @@ if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
+
+logger = logging.getLogger(__name__)
+
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
represented by a default `stream` attribute and a map of instance name to
stream position of any writers that are ahead of the default stream
position.
+
+ The values in `instance_map` must be greater than the `stream` attribute.
"""
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
kw_only=True,
)
+ def __attrs_post_init__(self) -> None:
+ # Enforce that all instances have a value greater than the min stream
+ # position.
+ for i, v in self.instance_map.items():
+ if v <= self.stream:
+ raise ValueError(
+ f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
+ )
+
@classmethod
@abc.abstractmethod
async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
for instance in set(self.instance_map).union(other.instance_map)
}
+ # Filter out any redundant entries.
+ instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
+
return attr.evolve(
self, stream=max_stream, instance_map=immutabledict(instance_map)
)
@@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
def bound_stream_token(self, max_stream: int) -> "Self":
"""Bound the stream positions to a maximum value"""
+ min_pos = min(self.stream, max_stream)
return type(self)(
- stream=min(self.stream, max_stream),
+ stream=min_pos,
instance_map=immutabledict(
- {k: min(s, max_stream) for k, s in self.instance_map.items()}
+ {
+ k: min(s, max_stream)
+ for k, s in self.instance_map.items()
+ if min(s, max_stream) > min_pos
+ }
),
)
@@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
)
+ super().__attrs_post_init__()
+
@classmethod
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
+ if not part:
+ # Handle tokens of the form `m5~`, which were created by
+ # a bug
+ continue
+
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
- pass
+            # We log an exception here as even though this *might* be a client
+            # handing us a bad token, it's more likely that Synapse returned a
+            # bad token (and we really want to catch those!).
+ logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return self.instance_map.get(instance_name, self.stream)
async def to_string(self, store: "DataStore") -> str:
+ """See class level docstring for information about the format."""
+
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
- encoded_map = "~".join(entries)
- return f"m{self.stream}~{encoded_map}"
+ if entries:
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ return f"s{self.stream}"
else:
return "s%d" % (self.stream,)
@@ -740,6 +777,13 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return super().bound_stream_token(max_stream)
+ def __str__(self) -> str:
+ instances = ", ".join(f"{k}: {v}" for k, v in sorted(self.instance_map.items()))
+ return (
+ f"RoomStreamToken(stream: {self.stream}, topological: {self.topological}, "
+ f"instances: {{{instances}}})"
+ )
+
@attr.s(frozen=True, slots=True, order=False)
class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
@@ -756,6 +800,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
+ if not part:
+ # Handle tokens of the form `m5~`, which were created by
+ # a bug
+ continue
+
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -770,10 +819,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
- pass
+            # We log an exception here as even though this *might* be a client
+            # handing us a bad token, it's more likely that Synapse returned a
+            # bad token (and we really want to catch those!).
+ logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid stream token %r" % (string,))
async def to_string(self, store: "DataStore") -> str:
+ """See class level docstring for information about the format."""
+
if self.instance_map:
entries = []
for name, pos in self.instance_map.items():
@@ -786,8 +840,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
- encoded_map = "~".join(entries)
- return f"m{self.stream}~{encoded_map}"
+ if entries:
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ return str(self.stream)
else:
return str(self.stream)
@@ -824,6 +880,13 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
return True
+ def __str__(self) -> str:
+ instances = ", ".join(f"{k}: {v}" for k, v in sorted(self.instance_map.items()))
+ return (
+ f"MultiWriterStreamToken(stream: {self.stream}, "
+ f"instances: {{{instances}}})"
+ )
+
class StreamKeyType(Enum):
"""Known stream types.
@@ -1082,6 +1145,15 @@ class StreamToken:
return True
+ def __str__(self) -> str:
+ return (
+ f"StreamToken(room: {self.room_key}, presence: {self.presence_key}, "
+ f"typing: {self.typing_key}, receipt: {self.receipt_key}, "
+ f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
+ f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
+ f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key})"
+ )
+
StreamToken.START = StreamToken(
RoomStreamToken(stream=0), 0, 0, MultiWriterStreamToken(stream=0), 0, 0, 0, 0, 0, 0
@@ -1170,11 +1242,12 @@ class ReadReceipt:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListUpdates:
"""
- An object containing a diff of information regarding other users' device lists, intended for
- a recipient to carry out device list tracking.
+ An object containing a diff of information regarding other users' device lists,
+ intended for a recipient to carry out device list tracking.
Attributes:
- changed: A set of users whose device lists have changed recently.
+        changed: A set of users who have updated their device identity or
+            cross-signing keys, or with whom we now share an encrypted room.
left: A set of users who the recipient no longer needs to track the device lists of.
Typically when those users no longer share any end-to-end encryption enabled rooms.
"""
diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py
index 43dcdf20dd..4c6c42db04 100644
--- a/synapse/types/handlers/__init__.py
+++ b/synapse/types/handlers/__init__.py
@@ -18,7 +18,7 @@
#
#
from enum import Enum
-from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Final, List, Mapping, Optional, Sequence, Tuple
import attr
from typing_extensions import TypedDict
@@ -31,7 +31,7 @@ else:
from pydantic import Extra
from synapse.events import EventBase
-from synapse.types import JsonDict, JsonMapping, StreamToken, UserID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, StreamToken, UserID
from synapse.types.rest.client import SlidingSyncBody
if TYPE_CHECKING:
@@ -200,18 +200,24 @@ class SlidingSyncResult:
flag set. (same as sync v2)
"""
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class StrippedHero:
+ user_id: str
+ display_name: Optional[str]
+ avatar_url: Optional[str]
+
name: Optional[str]
avatar: Optional[str]
- heroes: Optional[List[EventBase]]
+ heroes: Optional[List[StrippedHero]]
is_dm: bool
initial: bool
- # Only optional because it won't be included for invite/knock rooms with `stripped_state`
- required_state: Optional[List[EventBase]]
- # Only optional because it won't be included for invite/knock rooms with `stripped_state`
- timeline_events: Optional[List[EventBase]]
+ # Should be empty for invite/knock rooms with `stripped_state`
+ required_state: List[EventBase]
+ # Should be empty for invite/knock rooms with `stripped_state`
+ timeline_events: List[EventBase]
bundled_aggregations: Optional[Dict[str, "BundledAggregations"]]
# Optional because it's only relevant to invite/knock rooms
- stripped_state: Optional[List[JsonDict]]
+ stripped_state: List[JsonDict]
# Only optional because it won't be included for invite/knock rooms with `stripped_state`
prev_batch: Optional[StreamToken]
# Only optional because it won't be included for invite/knock rooms with `stripped_state`
@@ -252,10 +258,81 @@ class SlidingSyncResult:
count: int
ops: List[Operation]
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class Extensions:
+ """Responses for extensions
+
+ Attributes:
+ to_device: The to-device extension (MSC3885)
+ e2ee: The E2EE device extension (MSC3884)
+ """
+
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class ToDeviceExtension:
+ """The to-device extension (MSC3885)
+
+ Attributes:
+ next_batch: The to-device stream token the client should use
+ to get more results
+ events: A list of to-device messages for the client
+ """
+
+ next_batch: str
+ events: Sequence[JsonMapping]
+
+ def __bool__(self) -> bool:
+ return bool(self.events)
+
+ @attr.s(slots=True, frozen=True, auto_attribs=True)
+ class E2eeExtension:
+ """The E2EE device extension (MSC3884)
+
+ Attributes:
+                device_list_updates: The sets of users whose device lists have
+                    changed or left (only present on incremental syncs).
+ device_one_time_keys_count: Map from key algorithm to the number of
+ unclaimed one-time keys currently held on the server for this device. If
+ an algorithm is unlisted, the count for that algorithm is assumed to be
+ zero. If this entire parameter is missing, the count for all algorithms
+ is assumed to be zero.
+ device_unused_fallback_key_types: List of unused fallback key algorithms
+ for this device.
+ """
+
+ # Only present on incremental syncs
+ device_list_updates: Optional[DeviceListUpdates]
+ device_one_time_keys_count: Mapping[str, int]
+ device_unused_fallback_key_types: Sequence[str]
+
+ def __bool__(self) -> bool:
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ #
+ # Also related:
+ # https://github.com/element-hq/element-android/issues/3725 and
+ # https://github.com/matrix-org/synapse/issues/10456
+ default_otk = self.device_one_time_keys_count.get("signed_curve25519")
+ more_than_default_otk = len(self.device_one_time_keys_count) > 1 or (
+ default_otk is not None and default_otk > 0
+ )
+
+ return bool(
+ more_than_default_otk
+ or self.device_list_updates
+ or self.device_unused_fallback_key_types
+ )
+
+ to_device: Optional[ToDeviceExtension] = None
+ e2ee: Optional[E2eeExtension] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.to_device or self.e2ee)
+
next_pos: StreamToken
lists: Dict[str, SlidingWindowList]
rooms: Dict[str, RoomResult]
- extensions: JsonMapping
+ extensions: Extensions
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -271,5 +348,5 @@ class SlidingSyncResult:
next_pos=next_pos,
lists={},
rooms={},
- extensions={},
+ extensions=SlidingSyncResult.Extensions(),
)
diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py
index 55f6b44053..f3c45a0d6a 100644
--- a/synapse/types/rest/client/__init__.py
+++ b/synapse/types/rest/client/__init__.py
@@ -200,9 +200,6 @@ class SlidingSyncBody(RequestBodyModel):
}
timeline_limit: The maximum number of timeline events to return per response.
- include_heroes: Return a stripped variant of membership events (containing
- `user_id` and optionally `avatar_url` and `displayname`) for the users used
- to calculate the room name.
filters: Filters to apply to the list before sorting.
"""
@@ -270,16 +267,63 @@ class SlidingSyncBody(RequestBodyModel):
else:
ranges: Optional[List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]]] = None # type: ignore[valid-type]
slow_get_all_rooms: Optional[StrictBool] = False
- include_heroes: Optional[StrictBool] = False
filters: Optional[Filters] = None
class RoomSubscription(CommonRoomParameters):
pass
- class Extension(RequestBodyModel):
- enabled: Optional[StrictBool] = False
- lists: Optional[List[StrictStr]] = None
- rooms: Optional[List[StrictStr]] = None
+ class Extensions(RequestBodyModel):
+ """The extensions section of the request.
+
+ Extensions MUST have an `enabled` flag which defaults to `false`. If a client
+ sends an unknown extension name, the server MUST ignore it (or else backwards
+ compatibility between clients and servers is broken when a newer client tries to
+ communicate with an older server).
+ """
+
+ class ToDeviceExtension(RequestBodyModel):
+ """The to-device extension (MSC3885)
+
+ Attributes:
+ enabled
+ limit: Maximum number of to-device messages to return
+ since: The `next_batch` from the previous sync response
+ """
+
+ enabled: Optional[StrictBool] = False
+ limit: StrictInt = 100
+ since: Optional[StrictStr] = None
+
+ @validator("since")
+ def since_token_check(
+ cls, value: Optional[StrictStr]
+ ) -> Optional[StrictStr]:
+ # `since` comes in as an opaque string token but we know that it's just
+ # an integer representing the position in the device inbox stream. We
+ # want to pre-validate it to make sure it works fine in downstream code.
+ if value is None:
+ return value
+
+ try:
+ int(value)
+ except ValueError:
+ raise ValueError(
+ "'extensions.to_device.since' is invalid (should look like an int)"
+ )
+
+ return value
+
+ class E2eeExtension(RequestBodyModel):
+ """The E2EE device extension (MSC3884)
+
+ Attributes:
+ enabled
+ """
+
+ enabled: Optional[StrictBool] = False
+
+ to_device: Optional[ToDeviceExtension] = None
+ e2ee: Optional[E2eeExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
@@ -287,7 +331,7 @@ class SlidingSyncBody(RequestBodyModel):
else:
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
- extensions: Optional[Dict[StrictStr, Extension]] = None
+ extensions: Optional[Extensions] = None
@validator("lists")
def lists_length_check(
|