diff options
author | Erik Johnston <erik@matrix.org> | 2020-05-07 13:51:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-07 13:51:08 +0100 |
commit | d7983b63a6746d92225295f1e9d521f847cf8ba7 (patch) | |
tree | 13d581210d94c26bd75036592996f6f53f7d4bb2 /synapse/replication | |
parent | Merge pull request #7398 from Starbix/alpine-3.11 (diff) | |
download | synapse-d7983b63a6746d92225295f1e9d521f847cf8ba7.tar.xz |
Support any process writing to cache invalidation stream. (#7436)
Diffstat (limited to 'synapse/replication')
-rw-r--r-- | synapse/replication/slave/storage/_base.py | 50 | ||||
-rw-r--r-- | synapse/replication/slave/storage/account_data.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/deviceinbox.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/devices.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/events.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/groups.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/presence.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/push_rule.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/pushers.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/receipts.py | 6 | ||||
-rw-r--r-- | synapse/replication/slave/storage/room.py | 4 | ||||
-rw-r--r-- | synapse/replication/tcp/client.py | 6 | ||||
-rw-r--r-- | synapse/replication/tcp/commands.py | 33 | ||||
-rw-r--r-- | synapse/replication/tcp/handler.py | 42 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 22 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 87 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/events.py | 4 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/federation.py | 12 |
18 files changed, 131 insertions, 183 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 5d7c8871a4..2904bd0235 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -18,14 +18,10 @@ from typing import Optional import six -from synapse.storage.data_stores.main.cache import ( - CURRENT_STATE_CACHE_NAME, - CacheInvalidationWorkerStore, -) +from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine - -from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage.util.id_generators import MultiWriterIdGenerator logger = logging.getLogger(__name__) @@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): - self._cache_id_gen = SlavedIdTracker( - db_conn, "cache_invalidation_stream", "stream_id" - ) # type: Optional[SlavedIdTracker] + self._cache_id_gen = MultiWriterIdGenerator( + db_conn, + database, + instance_name=hs.get_instance_name(), + table="cache_invalidation_stream_by_instance", + instance_column="instance_name", + id_column="stream_id", + sequence_name="cache_invalidation_stream_seq", + ) # type: Optional[MultiWriterIdGenerator] else: self._cache_id_gen = None self.hs = hs - - def get_cache_stream_token(self): - if self._cache_id_gen: - return self._cache_id_gen.get_current_token() - else: - return 0 - - def process_replication_rows(self, stream_name, token, rows): - if stream_name == "caches": - if self._cache_id_gen: - self._cache_id_gen.advance(token) - for row in rows: - if row.cache_func == CURRENT_STATE_CACHE_NAME: - if row.keys is None: - raise Exception( - "Can't send an 'invalidate all' for current state cache" - ) - - room_id = row.keys[0] - members_changed = set(row.keys[1:]) - self._invalidate_state_caches(room_id, members_changed) - else: - self._attempt_to_invalidate_cache(row.cache_func, row.keys) - - def _invalidate_cache_and_stream(self, txn, cache_func, keys): - txn.call_after(cache_func.invalidate, keys) - txn.call_after(self._send_invalidation_poke, cache_func, keys) - - def _send_invalidation_poke(self, cache_func, keys): - self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys) diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 65e54b1c71..2a4f5c7cfd 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved def get_max_account_data_stream_id(self): return self._account_data_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "tag_account_data": self._account_data_id_gen.advance(token) for row in rows: @@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedAccountDataStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index c923751e50..6e7fd259d4 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): expiry_ms=30 * 60 * 1000, ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "to_device": self._device_inbox_id_gen.advance(token) for row in rows: @@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): self._device_federation_outbox_stream_cache.entity_has_changed( row.entity, token ) - return super(SlavedDeviceInboxStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 58fb0eaae3..9d8067342f 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) self._invalidate_caches_for_devices(token, rows) @@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._device_list_id_gen.advance(token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedDeviceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _invalidate_caches_for_devices(self, token, rows): for row in rows: diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 15011259df..b313720a4b 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -93,7 +93,7 @@ class SlavedEventStore( def get_room_min_stream_ordering(self): return self._backfill_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "events": self._stream_id_gen.advance(token) for row in rows: @@ -111,9 +111,7 @@ class SlavedEventStore( row.relates_to, backfilled=True, ) - return super(SlavedEventStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) def _process_event_stream_row(self, token, row): data = row.data diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py index 01bcf0e882..1851e7d525 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py @@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): def get_group_stream_token(self): return self._group_updates_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "groups": self._group_updates_id_gen.advance(token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedGroupServerStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index fae3125072..bd79ba99be 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore): def get_current_presence_token(self): return self._presence_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "presence": self._presence_id_gen.advance(token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) - return super(SlavedPresenceStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 6138796da4..5d5816d7eb 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "push_rules": self._push_rules_stream_id_gen.advance(token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) self.push_rules_stream_cache.entity_has_changed(row.user_id, token) - return super(SlavedPushRuleStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 67be337945..cb78b49acb 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): def get_pushers_stream_token(self): return self._pushers_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "pushers": self._pushers_id_gen.advance(token) - return super(SlavedPusherStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 993432edcb..be716cc558 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) self.get_receipts_for_room.invalidate((room_id, receipt_type)) - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "receipts": self._receipts_id_gen.advance(token) for row in rows: @@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): ) self._receipts_stream_cache.entity_has_changed(row.room_id, token) - return super(SlavedReceiptsStore, self).process_replication_rows( - stream_name, token, rows - ) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py index 10dda8708f..8873bf37e5 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py @@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - def process_replication_rows(self, stream_name, token, rows): + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == "public_rooms": self._public_room_id_gen.advance(token) - return super(RoomStore, self).process_replication_rows(stream_name, token, rows) + return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 3bbf3c3569..20cb8a654f 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -100,10 +100,10 @@ class ReplicationDataHandler: token: stream token for this batch of rows rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - self.store.process_replication_rows(stream_name, token, rows) + self.store.process_replication_rows(stream_name, instance_name, token, rows) - async def on_position(self, stream_name: str, token: int): - self.store.process_replication_rows(stream_name, token, []) + async def on_position(self, stream_name: str, instance_name: str, token: int): + self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index f58e384d17..c04f622816 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -341,37 +341,6 @@ class RemovePusherCommand(Command): return " ".join((self.app_id, self.push_key, self.user_id)) -class InvalidateCacheCommand(Command): - """Sent by the client to invalidate an upstream cache. - - THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE - NOT DISASTROUS IF WE DROP ON THE FLOOR. - - Mainly used to invalidate destination retry timing caches. - - Format:: - - INVALIDATE_CACHE <cache_func> <keys_json> - - Where <keys_json> is a json list. - """ - - NAME = "INVALIDATE_CACHE" - - def __init__(self, cache_func, keys): - self.cache_func = cache_func - self.keys = keys - - @classmethod - def from_line(cls, line): - cache_func, keys_json = line.split(" ", 1) - - return cls(cache_func, json.loads(keys_json)) - - def to_line(self): - return " ".join((self.cache_func, _json_encoder.encode(self.keys))) - - class UserIpCommand(Command): """Sent periodically when a worker sees activity from a client. @@ -439,7 +408,6 @@ _COMMANDS = ( UserSyncCommand, FederationAckCommand, RemovePusherCommand, - InvalidateCacheCommand, UserIpCommand, RemoteServerUpCommand, ClearUserSyncsCommand, @@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = ( ClearUserSyncsCommand.NAME, FederationAckCommand.NAME, RemovePusherCommand.NAME, - InvalidateCacheCommand.NAME, UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index b14a3d9fca..7c5d6c76e7 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -15,18 +15,7 @@ # limitations under the License. import logging -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Set, - Tuple, - TypeVar, -) +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar from prometheus_client import Counter @@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, FederationAckCommand, - InvalidateCacheCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -171,7 +159,7 @@ class ReplicationCommandHandler: return for stream_name, stream in self._streams.items(): - current_token = stream.current_token() + current_token = stream.current_token(self._instance_name) self.send_command( PositionCommand(stream_name, self._instance_name, current_token) ) @@ -210,18 +198,6 @@ class ReplicationCommandHandler: self._notifier.on_new_replication_data() - async def on_INVALIDATE_CACHE( - self, conn: AbstractConnection, cmd: InvalidateCacheCommand - ): - invalidate_cache_counter.inc() - - if self._is_master: - # We invalidate the cache locally, but then also stream that to other - # workers. - await self._store.invalidate_cache_and_stream( - cmd.cache_func, tuple(cmd.keys) - ) - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): user_ip_cache_counter.inc() @@ -295,7 +271,7 @@ class ReplicationCommandHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ - logger.debug("Received rdata %s -> %s", stream_name, token) + logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token) await self._replication_data_handler.on_rdata( stream_name, instance_name, token, rows ) @@ -326,7 +302,7 @@ class ReplicationCommandHandler: self._pending_batches.pop(stream_name, []) # Find where we previously streamed up to. - current_token = stream.current_token() + current_token = stream.current_token(cmd.instance_name) # If the position token matches our current token then we're up to # date and there's nothing to do. Otherwise, fetch all updates @@ -363,7 +339,9 @@ class ReplicationCommandHandler: logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position(stream_name, cmd.token) + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) self._streams_by_connection.setdefault(conn, set()).add(stream_name) @@ -491,12 +469,6 @@ class ReplicationCommandHandler: cmd = RemovePusherCommand(app_id, push_key, user_id) self.send_command(cmd) - def send_invalidate_cache(self, cache_func: Callable, keys: tuple): - """Poke the master to invalidate a cache. - """ - cmd = InvalidateCacheCommand(cache_func.__name__, keys) - self.send_command(cmd) - def send_user_ip( self, user_id: str, diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index b690abedad..002171ce7c 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol -from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream +from synapse.replication.tcp.streams import ( + STREAMS_MAP, + CachesStream, + FederationStream, + Stream, +) from synapse.util.metrics import Measure stream_updates_counter = Counter( @@ -71,11 +76,16 @@ class ReplicationStreamer(object): self.store = hs.get_datastore() self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self._instance_name = hs.get_instance_name() self._replication_torture_level = hs.config.replication_torture_level # Work out list of streams that this instance is the source of. self.streams = [] # type: List[Stream] + + # All workers can write to the cache invalidation stream. + self.streams.append(CachesStream(hs)) + if hs.config.worker_app is None: for stream in STREAMS_MAP.values(): if stream == FederationStream and hs.config.send_federation: @@ -83,6 +93,10 @@ class ReplicationStreamer(object): # has been disabled on the master. continue + if stream == CachesStream: + # We've already added it above. + continue + self.streams.append(stream(hs)) self.streams_by_name = {stream.NAME: stream for stream in self.streams} @@ -145,7 +159,9 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.current_token(): + if stream.last_token == stream.current_token( + self._instance_name + ): continue if self._replication_torture_level: @@ -157,7 +173,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.current_token(), + stream.current_token(self._instance_name), ) try: updates, current_token, limited = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 084604e8b0..b48a6a3e91 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -95,20 +95,25 @@ class Stream(object): def __init__( self, local_instance_name: str, - current_token_function: Callable[[], Token], + current_token_function: Callable[[str], Token], update_function: UpdateFunction, ): """Instantiate a Stream - current_token_function and update_function are callbacks which should be - implemented by subclasses. + `current_token_function` and `update_function` are callbacks which + should be implemented by subclasses. - current_token_function is called to get the current token of the underlying - stream. It is only meaningful on the process that is the source of the - replication stream (ie, usually the master). + `current_token_function` takes an instance name, which is a writer to + the stream, and returns the position in the stream of the writer (as + viewed from the current process). On the writer process this is where + the writer has successfully written up to, whereas on other processes + this is the position which we have received updates up to over + replication. (Note that most streams have a single writer and so their + implementations ignore the instance name passed in). - update_function is called to get updates for this stream between a pair of - stream tokens. See the UpdateFunction type definition for more info. + `update_function` is called to get updates for this stream between a + pair of stream tokens. See the `UpdateFunction` type definition for more + info. Args: local_instance_name: The instance name of the current process @@ -120,13 +125,13 @@ class Stream(object): self.update_function = update_function # The token from which we last asked for updates - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.last_token = self.current_token() + self.last_token = self.current_token(self.local_instance_name) async def get_updates(self) -> StreamUpdateResult: """Gets all updates since the last time this function was called (or @@ -138,7 +143,7 @@ class Stream(object): position in stream, and `limited` is whether there are more updates to fetch. """ - current_token = self.current_token() + current_token = self.current_token(self.local_instance_name) updates, current_token, limited = await self.get_updates_since( self.local_instance_name, self.last_token, current_token ) @@ -170,6 +175,16 @@ class Stream(object): return updates, upto_token, limited +def current_token_without_instance( + current_token: Callable[[], int] +) -> Callable[[str], int]: + """Takes a current token callback function for a single writer stream + that doesn't take an instance name parameter and wraps it in a function that + does accept an instance name parameter but ignores it. + """ + return lambda instance_name: current_token() + + def db_query_to_update_function( query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] ) -> UpdateFunction: @@ -235,7 +250,7 @@ class BackfillStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_backfill_token, + current_token_without_instance(store.get_current_backfill_token), db_query_to_update_function(store.get_all_new_backfill_event_rows), ) @@ -271,7 +286,9 @@ class PresenceStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), store.get_current_presence_token, update_function + hs.get_instance_name(), + current_token_without_instance(store.get_current_presence_token), + update_function, ) @@ -296,7 +313,9 @@ class TypingStream(Stream): update_function = make_http_update_function(hs, self.NAME) super().__init__( - hs.get_instance_name(), typing_handler.get_current_token, update_function + hs.get_instance_name(), + current_token_without_instance(typing_handler.get_current_token), + update_function, ) @@ -319,7 +338,7 @@ class ReceiptsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_receipt_stream_id, + current_token_without_instance(store.get_max_receipt_stream_id), db_query_to_update_function(store.get_all_updated_receipts), ) @@ -339,7 +358,7 @@ class PushRulesStream(Stream): hs.get_instance_name(), self._current_token, self._update_function ) - def _current_token(self) -> int: + def _current_token(self, instance_name: str) -> int: push_rules_token, _ = self.store.get_push_rules_stream_token() return push_rules_token @@ -373,7 +392,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), - store.get_pushers_stream_token, + current_token_without_instance(store.get_pushers_stream_token), db_query_to_update_function(store.get_all_updated_pushers_rows), ) @@ -402,12 +421,26 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - store = hs.get_datastore() + self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_cache_stream_token, - db_query_to_update_function(store.get_all_updated_caches), + self.store.get_cache_stream_token, + self._update_function, + ) + + async def _update_function( + self, instance_name: str, from_token: int, upto_token: int, limit: int + ): + rows = await self.store.get_all_updated_caches( + instance_name, from_token, upto_token, limit ) + updates = [(row[0], row[1:]) for row in rows] + limited = False + if len(updates) >= limit: + upto_token = updates[-1][0] + limited = True + + return updates, upto_token, limited class PublicRoomsStream(Stream): @@ -431,7 +464,7 @@ class PublicRoomsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_current_public_room_stream_id, + current_token_without_instance(store.get_current_public_room_stream_id), db_query_to_update_function(store.get_all_new_public_rooms), ) @@ -452,7 +485,7 @@ class DeviceListsStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function(store.get_all_device_list_changes_for_remotes), ) @@ -470,7 +503,7 @@ class ToDeviceStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_to_device_stream_token, + current_token_without_instance(store.get_to_device_stream_token), db_query_to_update_function(store.get_all_new_device_messages), ) @@ -490,7 +523,7 @@ class TagAccountDataStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_max_account_data_stream_id, + current_token_without_instance(store.get_max_account_data_stream_id), db_query_to_update_function(store.get_all_updated_tags), ) @@ -510,7 +543,7 @@ class AccountDataStream(Stream): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_max_account_data_stream_id, + current_token_without_instance(self.store.get_max_account_data_stream_id), db_query_to_update_function(self._update_function), ) @@ -541,7 +574,7 @@ class GroupServerStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_group_stream_token, + current_token_without_instance(store.get_group_stream_token), db_query_to_update_function(store.get_all_groups_changes), ) @@ -559,7 +592,7 @@ class UserSignatureStream(Stream): store = hs.get_datastore() super().__init__( hs.get_instance_name(), - store.get_device_stream_token, + current_token_without_instance(store.get_device_stream_token), db_query_to_update_function( store.get_all_user_signature_changes_for_remotes ), diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py index 890e75d827..f370390331 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py @@ -20,7 +20,7 @@ from typing import List, Tuple, Type import attr -from ._base import Stream, StreamUpdateResult, Token +from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance """Handling of the 'events' replication stream @@ -119,7 +119,7 @@ class EventsStream(Stream): self._store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self._store.get_current_events_token, + current_token_without_instance(self._store.get_current_events_token), self._update_function, ) diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py index b0505b8a2c..9bcd13b009 100644 --- a/synapse/replication/tcp/streams/federation.py +++ b/synapse/replication/tcp/streams/federation.py @@ -15,7 +15,11 @@ # limitations under the License. from collections import namedtuple -from synapse.replication.tcp.streams._base import Stream, make_http_update_function +from synapse.replication.tcp.streams._base import ( + Stream, + current_token_without_instance, + make_http_update_function, +) class FederationStream(Stream): @@ -41,7 +45,9 @@ class FederationStream(Stream): # will be a real FederationSender, which has stubs for current_token and # get_replication_rows.) federation_sender = hs.get_federation_sender() - current_token = federation_sender.get_current_token + current_token = current_token_without_instance( + federation_sender.get_current_token + ) update_function = federation_sender.get_replication_rows elif hs.should_send_federation(): @@ -58,7 +64,7 @@ class FederationStream(Stream): super().__init__(hs.get_instance_name(), current_token, update_function) @staticmethod - def _stub_current_token(): + def _stub_current_token(instance_name: str) -> int: # dummy current-token method for use on workers return 0 |