diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index be98b379eb..1b5262d667 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -699,10 +699,17 @@ class SlidingSyncHandler:
instance_to_max_stream_ordering_map[instance_name] = stream_ordering
# Then assemble the `RoomStreamToken`
+ min_stream_pos = min(instance_to_max_stream_ordering_map.values())
membership_snapshot_token = RoomStreamToken(
# Minimum position in the `instance_map`
- stream=min(instance_to_max_stream_ordering_map.values()),
- instance_map=immutabledict(instance_to_max_stream_ordering_map),
+ stream=min_stream_pos,
+ instance_map=immutabledict(
+ {
+ instance_name: stream_pos
+ for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+ if stream_pos > min_stream_pos
+ }
+ ),
)
# Since we fetched the users room list at some point in time after the from/to
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index b22a13ef01..3962ecc996 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -20,6 +20,7 @@
#
#
import abc
+import logging
import re
import string
from enum import Enum
@@ -74,6 +75,9 @@ if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
+
+logger = logging.getLogger(__name__)
+
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
represented by a default `stream` attribute and a map of instance name to
stream position of any writers that are ahead of the default stream
position.
+
+ The values in `instance_map` must be greater than the `stream` attribute.
"""
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
kw_only=True,
)
+ def __attrs_post_init__(self) -> None:
+ # Enforce that all instances have a value greater than the min stream
+ # position.
+ for i, v in self.instance_map.items():
+ if v <= self.stream:
+ raise ValueError(
+ f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
+ )
+
@classmethod
@abc.abstractmethod
async def parse(cls, store: "DataStore", string: str) -> "Self":
@@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
for instance in set(self.instance_map).union(other.instance_map)
}
+ # Filter out any redundant entries.
+ instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
+
return attr.evolve(
self, stream=max_stream, instance_map=immutabledict(instance_map)
)
@@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
def bound_stream_token(self, max_stream: int) -> "Self":
"""Bound the stream positions to a maximum value"""
+ min_pos = min(self.stream, max_stream)
return type(self)(
- stream=min(self.stream, max_stream),
+ stream=min_pos,
instance_map=immutabledict(
- {k: min(s, max_stream) for k, s in self.instance_map.items()}
+ {
+ k: min(s, max_stream)
+ for k, s in self.instance_map.items()
+ if min(s, max_stream) > min_pos
+ }
),
)
@@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
)
+ super().__attrs_post_init__()
+
@classmethod
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
+ if not part:
+ # Handle tokens of the form `m5~`, which were created by
+ # a bug
+ continue
+
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
- pass
+ # We log an exception here as even though this *might* be a client
+ # handing a bad token, it's more likely that Synapse returned a bad
+ # token (and we really want to catch those!).
+ logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid room stream token %r" % (string,))
@classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
return self.instance_map.get(instance_name, self.stream)
async def to_string(self, store: "DataStore") -> str:
+ """See class level docstring for information about the format."""
+
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
- encoded_map = "~".join(entries)
- return f"m{self.stream}~{encoded_map}"
+ if entries:
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ return f"s{self.stream}"
else:
return "s%d" % (self.stream,)
@@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_map = {}
for part in parts[1:]:
+ if not part:
+ # Handle tokens of the form `m5~`, which were created by
+ # a bug
+ continue
+
key, value = part.split(".")
instance_id = int(key)
pos = int(value)
@@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
except CancelledError:
raise
except Exception:
- pass
+ # We log an exception here as even though this *might* be a client
+ # handing a bad token, it's more likely that Synapse returned a bad
+ # token (and we really want to catch those!).
+ logger.exception("Failed to parse stream token: %r", string)
raise SynapseError(400, "Invalid stream token %r" % (string,))
async def to_string(self, store: "DataStore") -> str:
+ """See class level docstring for information about the format."""
+
if self.instance_map:
entries = []
for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
instance_id = await store.get_id_for_instance(name)
entries.append(f"{instance_id}.{pos}")
- encoded_map = "~".join(entries)
- return f"m{self.stream}~{encoded_map}"
+ if entries:
+ encoded_map = "~".join(entries)
+ return f"m{self.stream}~{encoded_map}"
+ return str(self.stream)
else:
return str(self.stream)
|