diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c463b60b26..5e137dbbf7 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -51,6 +51,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
+from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -260,6 +261,9 @@ class MockHomeserver:
def should_send_federation(self) -> bool:
return False
+ def get_replication_notifier(self) -> ReplicationNotifier:
+ return ReplicationNotifier()
+
class Porter:
def __init__(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eca75f1108..e386f77de6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -27,6 +27,7 @@ from typing import (
Iterable,
List,
Optional,
+ Set,
Tuple,
Union,
)
@@ -171,12 +172,23 @@ class FederationHandler:
self.third_party_event_rules = hs.get_third_party_event_rules()
+ # Tracks running partial state syncs by room ID.
+ # Partial state syncs currently only run on the main process, so it's okay to
+ # track them in-memory for now.
+ self._active_partial_state_syncs: Set[str] = set()
+ # Tracks partial state syncs we may want to restart.
+ # A dictionary mapping room IDs to (initial destination, other destinations)
+ # tuples.
+ self._partial_state_syncs_maybe_needing_restart: Dict[
+ str, Tuple[Optional[str], Collection[str]]
+ ] = {}
+
# if this is the main process, fire off a background process to resume
# any partial-state-resync operations which were in flight when we
# were shut down.
if not hs.config.worker.worker_app:
run_as_background_process(
- "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+ "resume_sync_partial_state_room", self._resume_partial_state_room_sync
)
@trace
@@ -679,9 +691,7 @@ class FederationHandler:
if ret.partial_state:
# Kick off the process of asynchronously fetching the state for this
# room.
- run_as_background_process(
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
+ self._start_partial_state_room_sync(
initial_destination=origin,
other_destinations=ret.servers_in_room,
room_id=room_id,
@@ -1660,20 +1670,100 @@ class FederationHandler:
# well.
return None
- async def _resume_sync_partial_state_room(self) -> None:
+ async def _resume_partial_state_room_sync(self) -> None:
"""Resumes resyncing of all partial-state rooms after a restart."""
assert not self.config.worker.worker_app
partial_state_rooms = await self.store.get_partial_state_room_resync_info()
for room_id, resync_info in partial_state_rooms.items():
- run_as_background_process(
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
+ self._start_partial_state_room_sync(
initial_destination=resync_info.joined_via,
other_destinations=resync_info.servers_in_room,
room_id=room_id,
)
+ def _start_partial_state_room_sync(
+ self,
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ """Starts the background process to resync the state of a partial state room,
+ if it is not already running.
+
+ Args:
+ initial_destination: the initial homeserver to pull the state from
+ other_destinations: other homeservers to try to pull the state from, if
+ `initial_destination` is unavailable
+ room_id: room to be resynced
+ """
+
+ async def _sync_partial_state_room_wrapper() -> None:
+ if room_id in self._active_partial_state_syncs:
+ # Another local user has joined the room while there is already a
+ # partial state sync running. This implies that there is a new join
+ # event to un-partial state. We might find ourselves in one of a few
+ # scenarios:
+ # 1. There is an existing partial state sync. The partial state sync
+ # un-partial states the new join event before completing and all is
+ # well.
+ # 2. Before the latest join, the homeserver was no longer in the room
+ # and there is an existing partial state sync from our previous
+ # membership of the room. The partial state sync may have:
+ # a) succeeded, but not yet terminated. The room will not be
+ # un-partial stated again unless we restart the partial state
+ # sync.
+ # b) failed, because we were no longer in the room and remote
+ # homeservers were refusing our requests, but not yet
+ # terminated. After the latest join, remote homeservers may
+ # start answering our requests again, so we should restart the
+ # partial state sync.
+ # In the cases where we would want to restart the partial state sync,
+ # the room would have the partial state flag when the partial state sync
+ # terminates.
+ self._partial_state_syncs_maybe_needing_restart[room_id] = (
+ initial_destination,
+ other_destinations,
+ )
+ return
+
+ self._active_partial_state_syncs.add(room_id)
+
+ try:
+ await self._sync_partial_state_room(
+ initial_destination=initial_destination,
+ other_destinations=other_destinations,
+ room_id=room_id,
+ )
+ finally:
+ # Read the room's partial state flag while we still hold the claim to
+ # being the active partial state sync (so that another partial state
+ # sync can't come along and mess with it under us).
+ # Normally, the partial state flag will be gone. If it isn't, then we
+ # may find ourselves in scenario 2a or 2b as described in the comment
+ # above, where we want to restart the partial state sync.
+ is_still_partial_state_room = await self.store.is_partial_state_room(
+ room_id
+ )
+ self._active_partial_state_syncs.remove(room_id)
+
+ if room_id in self._partial_state_syncs_maybe_needing_restart:
+ (
+ restart_initial_destination,
+ restart_other_destinations,
+ ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
+
+ if is_still_partial_state_room:
+ self._start_partial_state_room_sync(
+ initial_destination=restart_initial_destination,
+ other_destinations=restart_other_destinations,
+ room_id=room_id,
+ )
+
+ run_as_background_process(
+ desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
+ )
+
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 26b97cf766..28f0d4a25a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -226,8 +226,7 @@ class Notifier:
self.store = hs.get_datastores().main
self.pending_new_room_events: List[_PendingRoomEventEntry] = []
- # Called when there are new things to stream over replication
- self.replication_callbacks: List[Callable[[], None]] = []
+ self._replication_notifier = hs.get_replication_notifier()
self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
self._federation_client = hs.get_federation_http_client()
@@ -279,7 +278,7 @@ class Notifier:
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
- self.replication_callbacks.append(cb)
+ self._replication_notifier.add_replication_callback(cb)
def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
"""Add a callback that will be called when a user joins a room.
@@ -741,8 +740,7 @@ class Notifier:
def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
- for cb in self.replication_callbacks:
- cb()
+ self._replication_notifier.notify_replication()
def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
for cb in self._new_join_in_room_callbacks:
@@ -759,3 +757,26 @@ class Notifier:
# Tell the federation client about the fact the server is back up, so
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)
+
+
+@attr.s(auto_attribs=True)
+class ReplicationNotifier:
+ """Tracks callbacks for things that need to know about stream changes.
+
+ This is separate from the notifier to avoid circular dependencies.
+ """
+
+ _replication_callbacks: List[Callable[[], None]] = attr.Factory(list)
+
+ def add_replication_callback(self, cb: Callable[[], None]) -> None:
+ """Add a callback that will be called when some new data is available.
+ Callback is not given any arguments. It should *not* return a Deferred - if
+ it needs to do any asynchronous work, a background thread should be started and
+ wrapped with run_as_background_process.
+ """
+ self._replication_callbacks.append(cb)
+
+ def notify_replication(self) -> None:
+ """Notify the any replication listeners that there's a new event"""
+ for cb in self._replication_callbacks:
+ cb()
diff --git a/synapse/server.py b/synapse/server.py
index f4ab94c4f3..9d6d268f49 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -107,7 +107,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
-from synapse.notifier import Notifier
+from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
@@ -390,6 +390,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return Notifier(self)
@cache_in_self
+ def get_replication_notifier(self) -> ReplicationNotifier:
+ return ReplicationNotifier()
+
+ @cache_in_self
def get_auth(self) -> Auth:
return Auth(self)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 881d7089db..8a359d7eb8 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="account_data",
instance_name=self._instance_name,
tables=[
@@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
# SQLite).
self._account_data_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2179a8bf59..5b66431691 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
+ notifier=hs.get_replication_notifier(),
stream_name="caches",
instance_name=hs.get_instance_name(),
tables=[
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 713be91c5d..8e61aba454 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="to_device",
instance_name=self._instance_name,
tables=[("device_inbox", "instance_name", "stream_id")],
@@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
else:
self._can_write_to_device = True
self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_inbox", "stream_id"
+ db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index cd186c8472..903606fb46 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -92,6 +92,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
# class below that is used on the main process.
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"device_lists_stream",
"stream_id",
extra_tables=[
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4c691642e2..c4ac6c33ba 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
super().__init__(database, db_conn, hs)
self._cross_signing_id_gen = StreamIdGenerator(
- db_conn, "e2e_cross_signing_keys", "stream_id"
+ db_conn,
+ hs.get_replication_notifier(),
+ "e2e_cross_signing_keys",
+ "stream_id",
)
async def set_e2e_device_keys(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d150fa8a94..d8a8bcafb6 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -191,6 +191,7 @@ class EventsWorkerStore(SQLBaseStore):
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="events",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@@ -200,6 +201,7 @@ class EventsWorkerStore(SQLBaseStore):
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="backfill",
instance_name=hs.get_instance_name(),
tables=[("events", "instance_name", "stream_ordering")],
@@ -217,12 +219,14 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite).
self._stream_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"events",
"stream_ordering",
is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"events",
"stream_ordering",
step=-1,
@@ -300,6 +304,7 @@ class EventsWorkerStore(SQLBaseStore):
self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_event_stream",
instance_name=hs.get_instance_name(),
tables=[
@@ -311,7 +316,10 @@ class EventsWorkerStore(SQLBaseStore):
)
else:
self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
- db_conn, "un_partial_stated_event_stream", "stream_id"
+ db_conn,
+ hs.get_replication_notifier(),
+ "un_partial_stated_event_stream",
+ "stream_id",
)
def get_un_partial_stated_events_token(self) -> int:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 7b60815043..beb210f8ee 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
self._presence_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="presence_stream",
instance_name=self._instance_name,
tables=[("presence_stream", "instance_name", "stream_id")],
@@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
)
else:
self._presence_id_gen = StreamIdGenerator(
- db_conn, "presence_stream", "stream_id"
+ db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
)
self.hs = hs
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 03182887d1..14ca167b34 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -118,6 +118,7 @@ class PushRulesWorkerStore(
# class below that is used on the main process.
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"push_rules_stream",
"stream_id",
is_writer=hs.config.worker.worker_app is None,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7f24a3b6ec..df53e726e6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
# class below that is used on the main process.
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 86f5bce5f0..3468f354e6 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="receipts",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
@@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
# SQLite).
self._receipts_id_gen = StreamIdGenerator(
db_conn,
+ hs.get_replication_notifier(),
"receipts_linearized",
"stream_id",
is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 78906a5e1d..7264a33cd4 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -126,6 +126,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
+ notifier=hs.get_replication_notifier(),
stream_name="un_partial_stated_room_stream",
instance_name=self._instance_name,
tables=[
@@ -137,7 +138,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
)
else:
self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
- db_conn, "un_partial_stated_room_stream", "stream_id"
+ db_conn,
+ hs.get_replication_notifier(),
+ "un_partial_stated_room_stream",
+ "stream_id",
)
async def store_room(
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8670ffbfa3..9adff3f4f5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType
from typing import (
+ TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
@@ -49,6 +50,9 @@ from synapse.storage.database import (
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
+if TYPE_CHECKING:
+ from synapse.notifier import ReplicationNotifier
+
logger = logging.getLogger(__name__)
@@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def __init__(
self,
db_conn: LoggingDatabaseConnection,
+ notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
@@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ self._notifier = notifier
+
def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
@@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
with self._lock:
self._unfinished_ids.pop(next_id)
+ self._notifier.notify_replication()
+
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
@@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
+ self._notifier.notify_replication()
+
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
@@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self,
db_conn: LoggingDatabaseConnection,
db: DatabasePool,
+ notifier: "ReplicationNotifier",
stream_name: str,
instance_name: str,
tables: List[Tuple[str, str, str]],
@@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive: bool = True,
) -> None:
self._db = db
+ self._notifier = notifier
self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
@@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
# controls the return type. If `None` or omitted, the context manager yields
# a single integer stream_id; otherwise it yields a list of stream_ids.
- return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
+ return cast(
+ AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
+ )
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
# If we have a list of instances that are allowed to write to this
@@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
raise Exception("Tried to allocate stream ID on non-writer")
# Cast safety: see get_next.
- return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
+ return cast(
+ AsyncContextManager[List[int]],
+ _MultiWriterCtxManager(self, self._notifier, n),
+ )
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
@@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+ txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
@@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""
id_gen: MultiWriterIdGenerator
+ notifier: "ReplicationNotifier"
multiple_ids: Optional[int] = None
stream_ids: List[int] = attr.Factory(list)
@@ -814,6 +834,8 @@ class _MultiWriterCtxManager:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
+ self.notifier.notify_replication()
+
if exc_type is not None:
return False
|