diff --git a/changelog.d/17386.bugfix b/changelog.d/17386.bugfix
new file mode 100644
index 0000000000..9686b5c276
--- /dev/null
+++ b/changelog.d/17386.bugfix
@@ -0,0 +1 @@
+Fix bug where `/sync` requests could get blocked indefinitely after an upgrade from Synapse versions before v1.109.0.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c87eb748c0..c3ecf86ec4 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -764,6 +764,13 @@ class Notifier:
async def wait_for_stream_token(self, stream_token: StreamToken) -> bool:
"""Wait for this worker to catch up with the given stream token."""
+ current_token = self.event_sources.get_current_token()
+ if stream_token.is_before_or_eq(current_token):
+ return True
+
+ # Work around a bug where older Synapse versions gave out tokens "from
+ # the future", i.e. that are ahead of the tokens persisted in the DB.
+ stream_token = await self.event_sources.bound_future_token(stream_token)
start = self.clock.time_msec()
while True:
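
For intuition, a minimal self-contained sketch of the control flow this hunk adds. The `FakeEventSources` class, the integer tokens, and the polling/timeout details are stand-ins for illustration, not Synapse's actual notifier API:

```python
import asyncio
import time

class FakeEventSources:
    """Stand-in for Synapse's EventSources; positions are plain ints here."""

    def __init__(self) -> None:
        self.current = 5        # this worker's current stream position
        self.max_persisted = 7  # highest position actually in the DB

    def get_current_token(self) -> int:
        return self.current

    async def bound_future_token(self, token: int) -> int:
        # Clamp a token "from the future" to what was really persisted.
        return min(token, self.max_persisted)

async def wait_for_stream_token(sources: FakeEventSources, token: int) -> bool:
    # Fast path: already caught up, no DB round-trip needed.
    if token <= sources.get_current_token():
        return True

    # Clamp so the wait below is guaranteed to be able to terminate.
    token = await sources.bound_future_token(token)

    start = time.monotonic()
    while time.monotonic() - start < 2.0:  # illustrative timeout
        if token <= sources.get_current_token():
            return True
        sources.current += 1  # simulate replication advancing the stream
        await asyncio.sleep(0)
    return False

# A token of 10 was never persisted; without the clamp this would block
# until the timeout, with it we return once the stream reaches 7.
print(asyncio.run(wait_for_stream_token(FakeEventSources(), 10)))  # True
```
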
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9611a84932..966393869b 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -43,10 +43,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -71,7 +68,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._instance_name in hs.config.worker.writers.account_data
)
- self._account_data_id_gen: AbstractStreamIdGenerator
+ self._account_data_id_gen: MultiWriterIdGenerator
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
@@ -113,6 +110,9 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"""
return self._account_data_id_gen.get_current_token()
+ def get_account_data_id_generator(self) -> MultiWriterIdGenerator:
+ return self._account_data_id_gen
+
@cached()
async def get_global_account_data_for_user(
self, user_id: str
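
This hunk, and the matching hunks in the store files below, follow one pattern: the attribute annotation is narrowed from `AbstractStreamIdGenerator` to the concrete `MultiWriterIdGenerator` (the only implementation these stores ever assign), and a small accessor is added so `EventSources.bound_future_token` can reach the generator. The narrowing matters because the new `get_max_allocated_token` method exists only on the concrete class. A minimal sketch of the typing argument, with simplified stand-in classes:

```python
import abc
import asyncio

class AbstractStreamIdGenerator(abc.ABC):
    @abc.abstractmethod
    def get_current_token(self) -> int: ...

class MultiWriterIdGenerator(AbstractStreamIdGenerator):
    def get_current_token(self) -> int:
        return 3

    async def get_max_allocated_token(self) -> int:
        # Only the concrete class can ask its backing sequence for this.
        return 3

class AccountDataStore:
    def __init__(self) -> None:
        # Annotated as the concrete class: with the abstract base, the
        # get_max_allocated_token() call below would not type-check.
        self._account_data_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator()

    def get_account_data_id_generator(self) -> MultiWriterIdGenerator:
        return self._account_data_id_gen

store = AccountDataStore()
print(asyncio.run(store.get_account_data_id_generator().get_max_allocated_token()))
```
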
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 07333efff8..304ac42411 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -50,10 +50,7 @@ from synapse.storage.database import (
LoggingTransaction,
make_in_list_sql_clause,
)
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@@ -92,7 +89,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._instance_name in hs.config.worker.writers.to_device
)
- self._to_device_msg_id_gen: AbstractStreamIdGenerator = MultiWriterIdGenerator(
+ self._to_device_msg_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
notifier=hs.get_replication_notifier(),
@@ -169,6 +166,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self) -> int:
return self._to_device_msg_id_gen.get_current_token()
+ def get_to_device_id_generator(self) -> MultiWriterIdGenerator:
+ return self._to_device_msg_id_gen
+
async def get_messages_for_user_devices(
self,
user_ids: Collection[str],
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 59a035dd62..53024bddc3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -243,6 +243,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
+ def get_device_stream_id_generator(self) -> MultiWriterIdGenerator:
+ return self._device_list_id_gen
+
async def count_devices_by_users(
self, user_ids: Optional[Collection[str]] = None
) -> int:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e264d36f02..198e65cfa5 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -192,8 +192,8 @@ class EventsWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
- self._stream_id_gen: AbstractStreamIdGenerator
- self._backfill_id_gen: AbstractStreamIdGenerator
+ self._stream_id_gen: MultiWriterIdGenerator
+ self._backfill_id_gen: MultiWriterIdGenerator
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 923e764491..065c885603 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -42,10 +42,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
@@ -83,7 +80,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
- self._presence_id_gen: AbstractStreamIdGenerator
+ self._presence_id_gen: MultiWriterIdGenerator
self._can_persist_presence = (
self._instance_name in hs.config.worker.writers.presence
@@ -455,6 +452,9 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
def get_current_presence_token(self) -> int:
return self._presence_id_gen.get_current_token()
+ def get_presence_stream_id_gen(self) -> MultiWriterIdGenerator:
+ return self._presence_id_gen
+
def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 2a39dc9f90..bbdde17711 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -178,6 +178,9 @@ class PushRulesWorkerStore(
"""
return self._push_rules_stream_id_gen.get_current_token()
+ def get_push_rules_stream_id_gen(self) -> MultiWriterIdGenerator:
+ return self._push_rules_stream_id_gen
+
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 8432560a89..3bde0ae0d4 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -45,10 +45,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines._base import IsolationLevel
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import (
JsonDict,
JsonMapping,
@@ -76,7 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process.
- self._receipts_id_gen: AbstractStreamIdGenerator
+ self._receipts_id_gen: MultiWriterIdGenerator
self._can_write_to_receipts = (
self._instance_name in hs.config.worker.writers.receipts
@@ -136,6 +133,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
return self._receipts_id_gen.get_current_token_for_writer(instance_name)
+ def get_receipts_stream_id_gen(self) -> MultiWriterIdGenerator:
+ return self._receipts_id_gen
+
def get_last_unthreaded_receipt_for_user_txn(
self,
txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d5627b1d6e..80a4bf95f2 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -59,11 +59,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
- AbstractStreamIdGenerator,
- IdGenerator,
- MultiWriterIdGenerator,
-)
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -151,7 +147,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self.config: HomeServerConfig = hs.config
- self._un_partial_stated_rooms_stream_id_gen: AbstractStreamIdGenerator
+ self._un_partial_stated_rooms_stream_id_gen: MultiWriterIdGenerator
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
@@ -1409,6 +1405,9 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
instance_name
)
+ def get_un_partial_stated_rooms_id_generator(self) -> MultiWriterIdGenerator:
+ return self._un_partial_stated_rooms_stream_id_gen
+
async def get_un_partial_stated_rooms_between(
self, last_id: int, current_id: int, room_ids: Collection[str]
) -> Set[str]:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index ff0d723684..b7eb3116ae 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -577,6 +577,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
+ def get_events_stream_id_generator(self) -> MultiWriterIdGenerator:
+ return self._stream_id_gen
+
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 48f88a6f8a..e8588f33cf 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -812,6 +812,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
pos = self.get_current_token_for_writer(self._instance_name)
txn.execute(sql, (self._stream_name, self._instance_name, pos))
+ async def get_max_allocated_token(self) -> int:
+ return await self._db.runInteraction(
+ "get_max_allocated_token", self._sequence_gen.get_max_allocated
+ )
+
@attr.s(frozen=True, auto_attribs=True)
class _AsyncCtxManagerWrapper(Generic[T]):
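
`get_max_allocated_token` is a thin async wrapper: the read itself is synchronous DB work, so it is delegated to the sequence generator inside a database interaction rather than run on the reactor thread. A hedged sketch of that delegation, with the database pool and sequence replaced by trivial fakes:

```python
import asyncio
from typing import Any, Callable

class FakeDatabasePool:
    """Stand-in for Synapse's DatabasePool; runs the function inline."""

    async def runInteraction(self, desc: str, func: Callable[..., Any]) -> Any:
        txn = object()  # a real pool would hand over an open transaction
        return func(txn)

class FakeSequence:
    def get_max_allocated(self, txn: object) -> int:
        return 42  # a real generator reads last_value/is_called from Postgres

class IdGenerator:
    def __init__(self) -> None:
        self._db = FakeDatabasePool()
        self._sequence_gen = FakeSequence()

    async def get_max_allocated_token(self) -> int:
        # Mirror of the new method: wrap the sequence read in an interaction.
        return await self._db.runInteraction(
            "get_max_allocated_token", self._sequence_gen.get_max_allocated
        )

print(asyncio.run(IdGenerator().get_max_allocated_token()))  # 42
```
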
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c4c0602b28..cac3eba1a5 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -88,6 +88,10 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""
...
+ @abc.abstractmethod
+ def get_max_allocated(self, txn: Cursor) -> int:
+ """Get the maximum ID that we have allocated"""
+
class PostgresSequenceGenerator(SequenceGenerator):
"""An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -190,6 +194,17 @@ class PostgresSequenceGenerator(SequenceGenerator):
% {"seq": self._sequence_name, "stream_name": stream_name}
)
+ def get_max_allocated(self, txn: Cursor) -> int:
+ # Read the last value the sequence handed out; note that for a fresh,
+ # never-called sequence `last_value` is the value the *next* `nextval`
+ # will return, hence the subtraction below.
+ txn.execute(f"SELECT last_value, is_called FROM {self._sequence_name}")
+ row = txn.fetchone()
+ assert row is not None
+
+ last_value, is_called = row
+ if not is_called:
+ last_value -= 1
+ return last_value
+
GetFirstCallbackType = Callable[[Cursor], int]
@@ -248,6 +263,15 @@ class LocalSequenceGenerator(SequenceGenerator):
# There is nothing to do for in memory sequences
pass
+ def get_max_allocated(self, txn: Cursor) -> int:
+ # The in-memory sequence already tracks the maximum it has handed
+ # out; lazily initialise it from the DB on first use.
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ return self._current_max_id
+
def build_sequence_generator(
db_conn: "LoggingDatabaseConnection",
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index dd7401ac8e..93d5ae1a55 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -30,7 +30,12 @@ from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.logging.opentracing import trace
from synapse.streams import EventSource
-from synapse.types import MultiWriterStreamToken, StreamKeyType, StreamToken
+from synapse.types import (
+ AbstractMultiWriterStreamToken,
+ MultiWriterStreamToken,
+ StreamKeyType,
+ StreamToken,
+)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -91,6 +96,63 @@ class EventSources:
)
return token
+ async def bound_future_token(self, token: StreamToken) -> StreamToken:
+ """Bound a token that is ahead of the current token to the maximum
+ persisted values.
+
+ This ensures that if we wait for the given token we know the stream will
+ eventually advance to that point.
+
+ This works around a bug where older Synapse versions would give out
+ tokens ahead of the positions actually persisted in the DB, so that
+ after a restart the streams appear to have "gone backwards".
+ """
+
+ current_token = self.get_current_token()
+
+ stream_key_to_id_gen = {
+ StreamKeyType.ROOM: self.store.get_events_stream_id_generator(),
+ StreamKeyType.PRESENCE: self.store.get_presence_stream_id_gen(),
+ StreamKeyType.RECEIPT: self.store.get_receipts_stream_id_gen(),
+ StreamKeyType.ACCOUNT_DATA: self.store.get_account_data_id_generator(),
+ StreamKeyType.PUSH_RULES: self.store.get_push_rules_stream_id_gen(),
+ StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
+ StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
+ StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
+ }
+
+ for _, key in StreamKeyType.__members__.items():
+ if key == StreamKeyType.TYPING:
+ # The typing stream is allowed to "reset", so comparing positions
+ # across restarts doesn't really make sense as is.
+ # TODO: Figure out a better way of tracking resets.
+ continue
+
+ token_value = token.get_field(key)
+ current_value = current_token.get_field(key)
+
+ if isinstance(token_value, AbstractMultiWriterStreamToken):
+ assert type(current_value) is type(token_value)
+
+ if not token_value.is_before_or_eq(current_value): # type: ignore[arg-type]
+ max_token = await stream_key_to_id_gen[
+ key
+ ].get_max_allocated_token()
+
+ token = token.copy_and_replace(
+ key, token_value.bound_stream_token(max_token)
+ )
+ else:
+ assert isinstance(current_value, int)
+ if current_value < token_value:
+ max_token = await stream_key_to_id_gen[
+ key
+ ].get_max_allocated_token()
+
+ token = token.copy_and_replace(key, min(token_value, max_token))
+
+ return token
+
@trace
async def get_start_token_for_pagination(self, room_id: str) -> StreamToken:
"""Get the start token for a given room to be used to paginate
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 151658df53..8ab9f90238 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -536,6 +536,16 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
return True
+ def bound_stream_token(self, max_stream: int) -> "Self":
+ """Bound the stream positions to a maximum value"""
+
+ return type(self)(
+ stream=min(self.stream, max_stream),
+ instance_map=immutabledict(
+ {k: min(s, max_stream) for k, s in self.instance_map.items()}
+ ),
+ )
+
@attr.s(frozen=True, slots=True, order=False)
class RoomStreamToken(AbstractMultiWriterStreamToken):
@@ -722,6 +732,14 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
else:
return "s%d" % (self.stream,)
+ def bound_stream_token(self, max_stream: int) -> "RoomStreamToken":
+ """See super class"""
+
+ # This only makes sense for stream tokens.
+ assert self.topological is None
+
+ return super().bound_stream_token(max_stream)
+
@attr.s(frozen=True, slots=True, order=False)
class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
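
`bound_stream_token` clamps both the overall stream position and every per-writer position in the instance map; the `RoomStreamToken` override only adds a guard that topological (pagination) tokens are never clamped. A minimal model of the clamping, using a frozen dataclass in place of the attrs class:

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass(frozen=True)
class Token:
    """Simplified stand-in for AbstractMultiWriterStreamToken."""

    stream: int  # position up to which all writers have persisted
    instance_map: Dict[str, int] = field(default_factory=dict)

    def bound_stream_token(self, max_stream: int) -> "Token":
        # Clamp the global position and every per-writer position.
        return Token(
            stream=min(self.stream, max_stream),
            instance_map={
                k: min(v, max_stream) for k, v in self.instance_map.items()
            },
        )

tok = Token(stream=12, instance_map={"writer1": 12, "writer2": 8})
assert tok.bound_stream_token(10) == Token(10, {"writer1": 10, "writer2": 8})
```
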
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 02371ce724..5319928c28 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized
+from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules
@@ -35,7 +36,7 @@ from synapse.handlers.sync import SyncConfig, SyncRequestKey, SyncResult, SyncVe
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, StreamKeyType, UserID, create_requester
from synapse.util import Clock
import tests.unittest
@@ -959,6 +960,76 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.fail("No push rules found")
+ def test_wait_for_future_sync_token(self) -> None:
+ """Test that if we receive a token that is ahead of our current token,
+ we'll wait until the stream position advances.
+
+ This can happen if replication streams start lagging, and the client's
+ previous sync request was serviced by a worker ahead of ours.
+ """
+ user = self.register_user("alice", "password")
+
+ # We simulate a lagging stream by getting a stream ID from the ID gen
+ # and then delaying marking it as "persisted".
+ presence_id_gen = self.store.get_presence_stream_id_gen()
+ ctx_mgr = presence_id_gen.get_next()
+ stream_id = self.get_success(ctx_mgr.__aenter__())
+
+ # Create the new token based on the stream ID above.
+ current_token = self.hs.get_event_sources().get_current_token()
+ since_token = current_token.copy_and_advance(StreamKeyType.PRESENCE, stream_id)
+
+ sync_d = defer.ensureDeferred(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(user),
+ generate_sync_config(user),
+ sync_version=SyncVersion.SYNC_V2,
+ request_key=generate_request_key(),
+ since_token=since_token,
+ timeout=0,
+ )
+ )
+
+ # This should block waiting for the presence stream to update
+ self.pump()
+ self.assertFalse(sync_d.called)
+
+ # Marking the stream ID as persisted should unblock the request.
+ self.get_success(ctx_mgr.__aexit__(None, None, None))
+
+ self.get_success(sync_d, by=1.0)
+
+ def test_wait_for_invalid_future_sync_token(self) -> None:
+ """Like the previous test, except we give a token that has a stream
+ position ahead of what is in the DB, i.e. it's invalid and we shouldn't
+ wait for the stream to advance (as it may never do so).
+
+ This can happen due to older versions of Synapse giving out stream
+ positions without persisting them in the DB, and so on restart the
+ stream would get reset back to an older position.
+ """
+ user = self.register_user("alice", "password")
+
+ # Create a token and arbitrarily advance one of the streams.
+ current_token = self.hs.get_event_sources().get_current_token()
+ since_token = current_token.copy_and_advance(
+ StreamKeyType.PRESENCE, current_token.presence_key + 1
+ )
+
+ sync_d = defer.ensureDeferred(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(user),
+ generate_sync_config(user),
+ sync_version=SyncVersion.SYNC_V2,
+ request_key=generate_request_key(),
+ since_token=since_token,
+ timeout=0,
+ )
+ )
+
+ # We should return without waiting for the presence stream to advance.
+ self.get_success(sync_d)
+
def generate_sync_config(
user_id: str,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index bfb26139d3..12c11f342c 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -1386,10 +1386,12 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
# Create a future token that will cause us to wait. Since we never send a new
# event to reach that future stream_ordering, the worker will wait until the
# full timeout.
+ stream_id_gen = self.store.get_events_stream_id_generator()
+ stream_id = self.get_success(stream_id_gen.get_next().__aenter__())
current_token = self.event_sources.get_current_token()
future_position_token = current_token.copy_and_replace(
StreamKeyType.ROOM,
- RoomStreamToken(stream=current_token.room_key.stream + 1),
+ RoomStreamToken(stream=stream_id),
)
future_position_token_serialized = self.get_success(
|