diff options
-rw-r--r-- | changelog.d/17386.bugfix | 1 | ||||
-rw-r--r-- | synapse/notifier.py | 7 | ||||
-rw-r--r-- | synapse/storage/databases/main/account_data.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/deviceinbox.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/events_worker.py | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/presence.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/push_rule.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 10 | ||||
-rw-r--r-- | synapse/storage/databases/main/room.py | 11 | ||||
-rw-r--r-- | synapse/storage/databases/main/stream.py | 3 | ||||
-rw-r--r-- | synapse/storage/util/id_generators.py | 5 | ||||
-rw-r--r-- | synapse/storage/util/sequence.py | 24 | ||||
-rw-r--r-- | synapse/streams/events.py | 64 | ||||
-rw-r--r-- | synapse/types/__init__.py | 18 | ||||
-rw-r--r-- | tests/handlers/test_sync.py | 73 | ||||
-rw-r--r-- | tests/rest/client/test_sync.py | 4 |
17 files changed, 229 insertions, 31 deletions
diff --git a/changelog.d/17386.bugfix b/changelog.d/17386.bugfix new file mode 100644 index 0000000000..9686b5c276 --- /dev/null +++ b/changelog.d/17386.bugfix @@ -0,0 +1 @@ +Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0. diff --git a/synapse/notifier.py b/synapse/notifier.py index c87eb748c0..c3ecf86ec4 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -764,6 +764,13 @@ class Notifier: async def wait_for_stream_token(self, stream_token: StreamToken) -> bool: """Wait for this worker to catch up with the given stream token.""" + current_token = self.event_sources.get_current_token() + if stream_token.is_before_or_eq(current_token): + return True + + # Work around a bug where older Synapse versions gave out tokens "from + # the future", i.e. that are ahead of the tokens persisted in the DB. + stream_token = await self.event_sources.bound_future_token(stream_token) start = self.clock.time_msec() while True: diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 9611a84932..966393869b 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -43,10 +43,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.push_rule import PushRulesWorkerStore -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) self._instance_name in hs.config.worker.writers.account_data ) - self._account_data_id_gen: AbstractStreamIdGenerator + self._account_data_id_gen: MultiWriterIdGenerator self._account_data_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) """ return self._account_data_id_gen.get_current_token() + def get_account_data_id_generator(self) -> MultiWriterIdGenerator: + return self._account_data_id_gen + @cached() async def get_global_account_data_for_user( self, user_id: str diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 07333efff8..304ac42411 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -50,10 +50,7 @@ from synapse.storage.database import ( LoggingTransaction, make_in_list_sql_clause, ) -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache @@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): self._instance_name in hs.config.worker.writers.to_device ) - self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator( + self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator( db_conn=db_conn, db=database, notifier=hs.get_replication_notifier(), @@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore): def get_to_device_stream_token(self) -> int: return self._to_device_msg_id_gen.get_current_token() + def get_to_device_id_generator(self) -> MultiWriterIdGenerator: + return self._to_device_msg_id_gen + async def get_messages_for_user_devices( self, user_ids: Collection[str], diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 59a035dd62..53024bddc3 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_device_stream_token(self) -> int: return self._device_list_id_gen.get_current_token() + def get_device_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._device_list_id_gen + async def count_devices_by_users( self, user_ids: Optional[Collection[str]] = None ) -> int: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e264d36f02..198e65cfa5 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore): ): super().__init__(database, db_conn, hs) - self._stream_id_gen: AbstractStreamIdGenerator - self._backfill_id_gen: AbstractStreamIdGenerator + self._stream_id_gen: MultiWriterIdGenerator + self._backfill_id_gen: MultiWriterIdGenerator self._stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 923e764491..065c885603 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -42,10 +42,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines._base import IsolationLevel from synapse.storage.types import Connection -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.iterutils import batch_iter @@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() - self._presence_id_gen: AbstractStreamIdGenerator + self._presence_id_gen: MultiWriterIdGenerator self._can_persist_presence = ( self._instance_name in hs.config.worker.writers.presence @@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) def get_current_presence_token(self) -> int: return self._presence_id_gen.get_current_token() + def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._presence_id_gen + def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]: """Fetch non-offline presence from the database so that we can register the appropriate time outs. diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2a39dc9f90..bbdde17711 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -178,6 +178,9 @@ class PushRulesWorkerStore( """ return self._push_rules_stream_id_gen.get_current_token() + def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._push_rules_stream_id_gen + def process_replication_rows( self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 8432560a89..3bde0ae0d4 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -45,10 +45,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines._base import IsolationLevel -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.types import ( JsonDict, JsonMapping, @@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # In the worker store this is an ID tracker which we overwrite in the non-worker # class below that is used on the main process. - self._receipts_id_gen: AbstractStreamIdGenerator + self._receipts_id_gen: MultiWriterIdGenerator self._can_write_to_receipts = ( self._instance_name in hs.config.worker.writers.receipts @@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore): def get_receipt_stream_id_for_instance(self, instance_name: str) -> int: return self._receipts_id_gen.get_current_token_for_writer(instance_name) + def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator: + return self._receipts_id_gen + def get_last_unthreaded_receipt_for_user_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index d5627b1d6e..80a4bf95f2 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -59,11 +59,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import ( - AbstractStreamIdGenerator, - IdGenerator, - MultiWriterIdGenerator, -) +from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): self.config: HomeServerConfig = hs.config - self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator + self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator( db_conn=db_conn, @@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): instance_name ) + def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator: + return self._un_partial_stated_rooms_stream_id_gen + async def get_un_partial_stated_rooms_between( self, last_id: int, current_id: int, room_ids: Collection[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index ff0d723684..b7eb3116ae 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions)) + def get_events_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._stream_id_gen + async def get_room_events_stream_for_rooms( self, room_ids: Collection[str], diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 48f88a6f8a..e8588f33cf 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): pos = self.get_current_token_for_writer(self._instance_name) txn.execute(sql, (self._stream_name, self._instance_name, pos)) + async def get_max_allocated_token(self) -> int: + return await self._db.runInteraction( + "get_max_allocated_token", self._sequence_gen.get_max_allocated + ) + @attr.s(frozen=True, auto_attribs=True) class _AsyncCtxManagerWrapper(Generic[T]): diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index c4c0602b28..cac3eba1a5 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta): """ ... + @abc.abstractmethod + def get_max_allocated(self, txn: Cursor) -> int: + """Get the maximum ID that we have allocated""" + class PostgresSequenceGenerator(SequenceGenerator): """An implementation of SequenceGenerator which uses a postgres sequence""" @@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator): % {"seq": self._sequence_name, "stream_name": stream_name} ) + def get_max_allocated(self, txn: Cursor) -> int: + # We just read from the sequence what the last value we fetched was. + txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}") + row = txn.fetchone() + assert row is not None + + last_value, is_called = row + if not is_called: + last_value -= 1 + return last_value + GetFirstCallbackType = Callable[[Cursor], int] @@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator): # There is nothing to do for in memory sequences pass + def get_max_allocated(self, txn: Cursor) -> int: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + return self._current_max_id + def build_sequence_generator( db_conn: "LoggingDatabaseConnection", diff --git a/synapse/streams/events.py b/synapse/streams/events.py index dd7401ac8e..93d5ae1a55 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.logging.opentracing import trace from synapse.streams import EventSource -from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken +from synapse.types import ( + AbstractMultiWriterStreamToken, + MultiWriterStreamToken, + StreamKeyType, + StreamToken, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -91,6 +96,63 @@ class EventSources: ) return token + async def bound_future_token(self, token: StreamToken) -> StreamToken: + """Bound a token that is ahead of the current token to the maximum + persisted values. + + This ensures that if we wait for the given token we know the stream will + eventually advance to that point. + + This works around a bug where older Synapse versions will give out + tokens for streams, and then after a restart will give back tokens where + the stream has "gone backwards". + """ + + current_token = self.get_current_token() + + stream_key_to_id_gen = { + StreamKeyType.ROOM: self.store.get_events_stream_id_generator(), + StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(), + StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(), + StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(), + StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(), + StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(), + StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), + StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), + } + + for _, key in StreamKeyType.__members__.items(): + if key == StreamKeyType.TYPING: + # Typing stream is allowed to "reset", and so comparisons don't + # really make sense as is. + # TODO: Figure out a better way of tracking resets. + continue + + token_value = token.get_field(key) + current_value = current_token.get_field(key) + + if isinstance(token_value, AbstractMultiWriterStreamToken): + assert type(current_value) is type(token_value) + + if not token_value.is_before_or_eq(current_value): # type: ignore[arg-type] + max_token = await stream_key_to_id_gen[ + key + ].get_max_allocated_token() + + token = token.copy_and_replace( + key, token.room_key.bound_stream_token(max_token) + ) + else: + assert isinstance(current_value, int) + if current_value < token_value: + max_token = await stream_key_to_id_gen[ + key + ].get_max_allocated_token() + + token = token.copy_and_replace(key, min(token_value, max_token)) + + return token + @trace async def get_start_token_for_pagination(self, room_id: str) -> StreamToken: """Get the start token for a given room to be used to paginate diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 151658df53..8ab9f90238 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): return True + def bound_stream_token(self, max_stream: int) -> "Self": + """Bound the stream positions to a maximum value""" + + return type(self)( + stream=min(self.stream, max_stream), + instance_map=immutabledict( + {k: min(s, max_stream) for k, s in self.instance_map.items()} + ), + ) + @attr.s(frozen=True, slots=True, order=False) class RoomStreamToken(AbstractMultiWriterStreamToken): @@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): else: return "s%d" % (self.stream,) + def bound_stream_token(self, max_stream: int) -> "RoomStreamToken": + """See super class""" + + # This only makes sense for stream tokens. + assert self.topological is None + + return super().bound_stream_token(max_stream) + @attr.s(frozen=True, slots=True, order=False) class MultiWriterStreamToken(AbstractMultiWriterStreamToken): diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 02371ce724..5319928c28 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized +from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules @@ -35,7 +36,7 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester +from synapse.types import JsonDict, StreamKeyType, UserID, create_requester from synapse.util import Clock import tests.unittest @@ -959,6 +960,76 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): self.fail("No push rules found") + def test_wait_for_future_sync_token(self) -> None: + """Test that if we receive a token that is ahead of our current token, + we'll wait until the stream position advances. + + This can happen if replication streams start lagging, and the client's + previous sync request was serviced by a worker ahead of ours. + """ + user = self.register_user("alice", "password") + + # We simulate a lagging stream by getting a stream ID from the ID gen + # and then waiting to mark it as "persisted". + presence_id_gen = self.store.get_presence_stream_id_gen() + ctx_mgr = presence_id_gen.get_next() + stream_id = self.get_success(ctx_mgr.__aenter__()) + + # Create the new token based on the stream ID above. + current_token = self.hs.get_event_sources().get_current_token() + since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id) + + sync_d = defer.ensureDeferred( + self.sync_handler.wait_for_sync_for_user( + create_requester(user), + generate_sync_config(user), + sync_version=SyncVersion.SYNC_V2, + request_key=generate_request_key(), + since_token=since_token, + timeout=0, + ) + ) + + # This should block waiting for the presence stream to update + self.pump() + self.assertFalse(sync_d.called) + + # Marking the stream ID as persisted should unblock the request. + self.get_success(ctx_mgr.__aexit__(None, None, None)) + + self.get_success(sync_d, by=1.0) + + def test_wait_for_invalid_future_sync_token(self) -> None: + """Like the previous test, except we give a token that has a stream + position ahead of what is in the DB, i.e. its invalid and we shouldn't + wait for the stream to advance (as it may never do so). + + This can happen due to older versions of Synapse giving out stream + positions without persisting them in the DB, and so on restart the + stream would get reset back to an older position. + """ + user = self.register_user("alice", "password") + + # Create a token and arbitrarily advance one of the streams. + current_token = self.hs.get_event_sources().get_current_token() + since_token = current_token.copy_and_advance( + StreamKeyType.PRESENCE, current_token.presence_key + 1 + ) + + sync_d = defer.ensureDeferred( + self.sync_handler.wait_for_sync_for_user( + create_requester(user), + generate_sync_config(user), + sync_version=SyncVersion.SYNC_V2, + request_key=generate_request_key(), + since_token=since_token, + timeout=0, + ) + ) + + # We should return without waiting for the presence stream to advance. + self.get_success(sync_d) + def generate_sync_config( user_id: str, diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index bfb26139d3..12c11f342c 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -1386,10 +1386,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Create a future token that will cause us to wait. Since we never send a new # event to reach that future stream_ordering, the worker will wait until the # full timeout. + stream_id_gen = self.store.get_events_stream_id_generator() + stream_id = self.get_success(stream_id_gen.get_next().__aenter__()) current_token = self.event_sources.get_current_token() future_position_token = current_token.copy_and_replace( StreamKeyType.ROOM, - RoomStreamToken(stream=current_token.room_key.stream + 1), + RoomStreamToken(stream=stream_id), ) future_position_token_serialized = self.get_success( |