diff --git a/synapse/notifier.py b/synapse/notifier.py
index c87eb748c0..c3ecf86ec4 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -764,6 +764,13 @@ class Notifier:
async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
"""Wait for this worker to catch up with the given stream token."""
+ current_token = self.event_sources.get_current_token()
+ if stream_token.is_before_or_eq(current_token):
+ return True
+
+ # Work around a bug where older Synapse versions gave out tokens "from
+ # the future", i.e. tokens ahead of the positions persisted in the DB.
+ stream_token = await self.event_sources.bound_future_token(stream_token)
start = self.clock.time_msec()
while True:
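
For orientation, here is a simplified, self-contained sketch of the control flow this hunk gives `wait_for_stream_token`: return immediately if the worker has already passed the token, otherwise clamp a "future" token before polling. The `Stream` class and integer tokens below are illustrative stand-ins, not Synapse's real types.

```python
import asyncio


class Stream:
    """Stand-in for a replicated stream (illustration only)."""

    def __init__(self) -> None:
        self.current = 10         # position this worker has caught up to
        self.max_persisted = 12   # highest position actually in the DB

    def bound(self, token: int) -> int:
        # Clamp a token "from the future" to what is really persisted,
        # so the position we wait for below is actually reachable.
        return min(token, self.max_persisted)


async def wait_for_stream_token(
    stream: Stream, token: int, timeout_s: float = 0.2
) -> bool:
    if token <= stream.current:  # fast path added by this hunk
        return True
    token = stream.bound(token)  # work around tokens "from the future"
    deadline = asyncio.get_running_loop().time() + timeout_s
    while asyncio.get_running_loop().time() < deadline:
        if token <= stream.current:
            return True
        await asyncio.sleep(0.01)  # real code waits on the notifier instead
    return False


assert asyncio.run(wait_for_stream_token(Stream(), 9))       # already caught up
assert not asyncio.run(wait_for_stream_token(Stream(), 15))  # bounded to 12; stream never advances
```
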
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9611a84932..966393869b 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,10 +43,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._instance_name in hs.config.worker.writers.account_data
)
- self._account_data_id_gen: AbstractStreamIdGenerator
+ self._account_data_id_gen: MultiWriterIdGenerator
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
@@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"""
return self._account_data_id_gen.get_current_token()
+ def get_account_data_id_generator(self) -> MultiWriterIdGenerator:
+ return self._account_data_id_gen
+
@cached()
async def get_global_account_data_for_user(
self, user_id: str
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 5a752b9b8c..042d595ea0 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -50,10 +50,7 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._instance_name in hs.config.worker.writers.to_device
)
- self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+ self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
@@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self) -> int:
return self._to_device_msg_id_gen.get_current_token()
+ def get_to_device_id_generator(self) -> MultiWriterIdGenerator:
+ return self._to_device_msg_id_gen
+
async def get_messages_for_user_devices(
self,
user_ids: Collection[str],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 59a035dd62..53024bddc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
+ def get_device_stream_id_generator(self) -> MultiWriterIdGenerator:
+ return self._device_list_id_gen
+
async def count_devices_by_users(
self, user_ids: Optional[Collection[str]] = None
) -> int:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e264d36f02..198e65cfa5 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._stream_id_gen: AbstractStreamIdGenerator
- self._backfill_id_gen: AbstractStreamIdGenerator
+ self._stream_id_gen: MultiWriterIdGenerator
+ self._backfill_id_gen: MultiWriterIdGenerator
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 923e764491..065c885603 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -42,10 +42,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
@@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
- self._presence_id_gen: AbstractStreamIdGenerator
+ self._presence_id_gen: MultiWriterIdGenerator
self._can_persist_presence = (
self._instance_name in hs.config.worker.writers.presence
@@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()
+ def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator:
+ return self._presence_id_gen
+
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2a39dc9f90..bbdde17711 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -178,6 +178,9 @@ class PushRulesWorkerStore(
"""
return self._push_rules_stream_id_gen.get_current_token()
+ def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator:
+ return self._push_rules_stream_id_gen
+
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 8432560a89..3bde0ae0d4 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -45,10 +45,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines._base import IsolationLevel
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import (
JsonDict,
JsonMapping,
@@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._receipts_id_gen: AbstractStreamIdGenerator
+ self._receipts_id_gen: MultiWriterIdGenerator
self._can_write_to_receipts = (
self._instance_name in hs.config.worker.writers.receipts
@@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
return self._receipts_id_gen.get_current_token_for_writer(instance_name)
+ def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator:
+ return self._receipts_id_gen
+
def get_last_unthreaded_receipt_for_user_txn(
self,
txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d5627b1d6e..80a4bf95f2 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -59,11 +59,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- IdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self.config: HomeServerConfig = hs.config
- self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+ self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
@@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
instance_name
)
+ def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator:
+ return self._un_partial_stated_rooms_stream_id_gen
+
async def get_un_partial_stated_rooms_between(
self, last_id: int, current_id: int, room_ids: Collection[str]
) -> Set[str]:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index ff0d723684..b7eb3116ae 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
+ def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
+ return self._stream_id_gen
+
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 48f88a6f8a..e8588f33cf 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
pos = self.get_current_token_for_writer(self._instance_name)
txn.execute(sql, (self._stream_name, self._instance_name, pos))
+ async def get_max_allocated_token(self) -> int:
+ return await self._db.runInteraction(
+ "get_max_allocated_token", self._sequence_gen.get_max_allocated
+ )
+
@attr.s(frozen=True, auto_attribs=True)
class _AsyncCtxManagerWrapper(Generic[T]):
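
A hedged usage sketch of how a caller might combine the new store accessors with `get_max_allocated_token`; the `hs` homeserver object is assumed to be available, and the accessor names mirror this patch.

```python
# Hypothetical caller: read the maximum persisted receipts position via the
# accessor and method added by this patch.
async def max_persisted_receipt_position(hs) -> int:
    store = hs.get_datastores().main               # existing Synapse accessor
    id_gen = store.get_receipts_stream_id_gen()    # added by this patch
    return await id_gen.get_max_allocated_token()  # added by this patch
```
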
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c4c0602b28..cac3eba1a5 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""
...
+ @abc.abstractmethod
+ def get_max_allocated(self, txn: Cursor) -> int:
+ """Get the maximum ID that we have allocated"""
+
class PostgresSequenceGenerator(SequenceGenerator):
"""An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator):
% {"seq": self._sequence_name, "stream_name": stream_name}
)
+ def get_max_allocated(self, txn: Cursor) -> int:
+ # Read the last value the sequence handed out; if `is_called` is false
+ # the sequence is unused and `last_value` is the *next* value instead.
+ txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}")
+ row = txn.fetchone()
+ assert row is not None
+
+ last_value, is_called = row
+ if not is_called:
+ last_value -= 1
+ return last_value
+
GetFirstCallbackType = Callable[[Cursor], int]
@@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator):
# There is nothing to do for in memory sequences
pass
+ def get_max_allocated(self, txn: Cursor) -> int:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ return self._current_max_id
+
def build_sequence_generator(
db_conn: "LoggingDatabaseConnection",
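
The Postgres `last_value`/`is_called` adjustment above is subtle: a sequence that has never been called reports its *next* value rather than the last one allocated. A minimal stand-alone sketch of the rule, with example values chosen purely for illustration:

```python
def max_allocated(last_value: int, is_called: bool) -> int:
    # Postgres reports is_called = false until the first nextval(); in that
    # state last_value is the *next* ID, so nothing has been allocated yet.
    return last_value if is_called else last_value - 1


assert max_allocated(7, True) == 7    # nextval() last returned 7
assert max_allocated(1, False) == 0   # fresh sequence: nothing allocated yet
```
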
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index dd7401ac8e..93d5ae1a55 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
-from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
+from synapse.types import (
+ AbstractMultiWriterStreamToken,
+ MultiWriterStreamToken,
+ StreamKeyType,
+ StreamToken,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -91,6 +96,63 @@ class EventSources:
)
return token
+ async def bound_future_token(self, token: StreamToken) -> StreamToken:
+ """Bound a token that is ahead of the current token to the maximum
+ persisted values.
+
+ This ensures that if we wait for the given token we know the stream will
+ eventually advance to that point.
+
+ This works around a bug where older Synapse versions would give out
+ tokens for streams and then, after a restart, hand out tokens where the
+ stream had "gone backwards".
+ """
+
+ current_token = self.get_current_token()
+
+ stream_key_to_id_gen = {
+ StreamKeyType.ROOM: self.store.get_events_stream_id_generator(),
+ StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(),
+ StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(),
+ StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(),
+ StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(),
+ StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
+ StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
+ }
+
+ for key in StreamKeyType.__members__.values():
+ if key == StreamKeyType.TYPING:
+ # Typing stream is allowed to "reset", and so comparisons don't
+ # really make sense as is.
+ # TODO: Figure out a better way of tracking resets.
+ continue
+
+ token_value = token.get_field(key)
+ current_value = current_token.get_field(key)
+
+ if isinstance(token_value, AbstractMultiWriterStreamToken):
+ assert type(current_value) is type(token_value)
+
+ if not token_value.is_before_or_eq(current_value): # type: ignore[arg-type]
+ max_token = await stream_key_to_id_gen[
+ key
+ ].get_max_allocated_token()
+
+ token = token.copy_and_replace(
+ key, token_value.bound_stream_token(max_token)
+ )
+ else:
+ assert isinstance(current_value, int)
+ if current_value < token_value:
+ max_token = await stream_key_to_id_gen[
+ key
+ ].get_max_allocated_token()
+
+ token = token.copy_and_replace(key, min(token_value, max_token))
+
+ return token
+
@trace
async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
"""Get the start token for a given room to be used to paginate
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 151658df53..8ab9f90238 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
return True
+ def bound_stream_token(self, max_stream: int) -> "Self":
+ """Bound the stream positions to a maximum value"""
+
+ return type(self)(
+ stream=min(self.stream, max_stream),
+ instance_map=immutabledict(
+ {k: min(s, max_stream) for k, s in self.instance_map.items()}
+ ),
+ )
+
@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):
@@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
else:
return "s%d" % (self.stream,)
+ def bound_stream_token(self, max_stream: int) -> "RoomStreamToken":
+ """See super class"""
+
+ # This only makes sense for stream tokens.
+ assert self.topological is None
+
+ return super().bound_stream_token(max_stream)
+
@attr.s(frozen=True, slots=True, order=False)
class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
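
Finally, a sketch of the invariant the `RoomStreamToken` override enforces: only pure stream tokens (`s123`) can be bounded, while topological tokens (`t45-123`) are refused. The `Token` class below is a hypothetical stand-in, not Synapse's `RoomStreamToken`.

```python
from typing import Optional


class Token:
    """Hypothetical stand-in for RoomStreamToken (illustration only)."""

    def __init__(self, stream: int, topological: Optional[int] = None) -> None:
        self.stream = stream
        self.topological = topological

    def bound_stream_token(self, max_stream: int) -> "Token":
        # A topological token pins an event ordering; clamping just its
        # stream part would produce an inconsistent token, so refuse.
        assert self.topological is None
        return Token(stream=min(self.stream, max_stream))


assert Token(stream=20).bound_stream_token(12).stream == 12
try:
    Token(stream=20, topological=5).bound_stream_token(12)
except AssertionError:
    pass  # expected: topological tokens cannot be bounded
```
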