diff options
Diffstat (limited to 'synapse/replication')
-rw-r--r-- | synapse/replication/slave/storage/devices.py | 36 | ||||
-rw-r--r-- | synapse/replication/tcp/resource.py | 9 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/__init__.py | 70 | ||||
-rw-r--r-- | synapse/replication/tcp/streams/_base.py | 79 |
4 files changed, 112 insertions, 82 deletions
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 1c77687eea..23b1650e41 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self.hs = hs self._device_list_id_gen = SlavedIdTracker( - db_conn, "device_lists_stream", "stream_id" + db_conn, + "device_lists_stream", + "stream_id", + extra_tables=[ + ("user_signature_stream", "stream_id"), + ("device_lists_outbound_pokes", "stream_id"), + ], ) device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( @@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto def process_replication_rows(self, stream_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(token) - for row in rows: - self._invalidate_caches_for_devices(token, row.user_id, row.destination) + self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super(SlavedDeviceStore, self).process_replication_rows( stream_name, token, rows ) - def _invalidate_caches_for_devices(self, token, user_id, destination): - self._device_list_stream_cache.entity_has_changed(user_id, token) - - if destination: - self._device_list_federation_stream_cache.entity_has_changed( - destination, token - ) + def _invalidate_caches_for_devices(self, token, rows): + for row in rows: + # The entities are either user IDs (starting with '@') whose devices + # have changed, or remote servers that we need to tell about + # changes. + if row.entity.startswith("@"): + self._device_list_stream_cache.entity_has_changed(row.entity, token) + self.get_cached_devices_for_user.invalidate((row.entity,)) + self._get_cached_user_device.invalidate_many((row.entity,)) + self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) - self.get_cached_devices_for_user.invalidate((user_id,)) - self._get_cached_user_device.invalidate_many((user_id,)) - self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) + else: + self._device_list_federation_stream_cache.entity_has_changed( + row.entity, token + ) diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index ce9d1fae12..6e2ebaf614 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -166,11 +166,6 @@ class ReplicationStreamer(object): self.pending_updates = False with Measure(self.clock, "repl.stream.get_updates"): - # First we tell the streams that they should update their - # current tokens. - for stream in self.streams: - stream.advance_current_token() - all_streams = self.streams if self._replication_torture_level is not None: @@ -180,7 +175,7 @@ class ReplicationStreamer(object): random.shuffle(all_streams) for stream in all_streams: - if stream.last_token == stream.upto_token: + if stream.last_token == stream.current_token(): continue if self._replication_torture_level: @@ -192,7 +187,7 @@ class ReplicationStreamer(object): "Getting stream: %s: %s -> %s", stream.NAME, stream.last_token, - stream.upto_token, + stream.current_token(), ) try: updates, current_token = await stream.get_updates() diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 5f52264e84..29199f5b46 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -24,27 +24,61 @@ Each stream is defined by the following information: current_token: The function that returns the current token for the stream update_function: The function that returns a list of updates between two tokens """ - -from . import _base, events, federation +from synapse.replication.tcp.streams._base import ( + AccountDataStream, + BackfillStream, + CachesStream, + DeviceListsStream, + GroupServerStream, + PresenceStream, + PublicRoomsStream, + PushersStream, + PushRulesStream, + ReceiptsStream, + TagAccountDataStream, + ToDeviceStream, + TypingStream, + UserSignatureStream, +) +from synapse.replication.tcp.streams.events import EventsStream +from synapse.replication.tcp.streams.federation import FederationStream STREAMS_MAP = { stream.NAME: stream for stream in ( - events.EventsStream, - _base.BackfillStream, - _base.PresenceStream, - _base.TypingStream, - _base.ReceiptsStream, - _base.PushRulesStream, - _base.PushersStream, - _base.CachesStream, - _base.PublicRoomsStream, - _base.DeviceListsStream, - _base.ToDeviceStream, - federation.FederationStream, - _base.TagAccountDataStream, - _base.AccountDataStream, - _base.GroupServerStream, - _base.UserSignatureStream, + EventsStream, + BackfillStream, + PresenceStream, + TypingStream, + ReceiptsStream, + PushRulesStream, + PushersStream, + CachesStream, + PublicRoomsStream, + DeviceListsStream, + ToDeviceStream, + FederationStream, + TagAccountDataStream, + AccountDataStream, + GroupServerStream, + UserSignatureStream, ) } + +__all__ = [ + "STREAMS_MAP", + "BackfillStream", + "PresenceStream", + "TypingStream", + "ReceiptsStream", + "PushRulesStream", + "PushersStream", + "CachesStream", + "PublicRoomsStream", + "DeviceListsStream", + "ToDeviceStream", + "TagAccountDataStream", + "AccountDataStream", + "GroupServerStream", + "UserSignatureStream", +] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 208e8a667b..abf5c6c6a8 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -17,10 +17,12 @@ import itertools import logging from collections import namedtuple -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import attr +from synapse.types import JsonDict + logger = logging.getLogger(__name__) @@ -94,9 +96,13 @@ PublicRoomsStreamRow = namedtuple( "network_id", # str, optional ), ) -DeviceListsStreamRow = namedtuple( - "DeviceListsStreamRow", ("user_id", "destination") # str # str -) + + +@attr.s +class DeviceListsStreamRow: + entity = attr.ib(type=str) + + ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str TagAccountDataStreamRow = namedtuple( "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict @@ -115,13 +121,12 @@ class Stream(object): """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last - time it was called up until the point `advance_current_token` was called. + time it was called. """ NAME = None # type: str # The name of the stream # The type of the row. Used by the default impl of parse_row. ROW_TYPE = None # type: Any - _LIMITED = True # Whether the update function takes a limit @classmethod def parse_row(cls, row): @@ -142,26 +147,15 @@ class Stream(object): # The token from which we last asked for updates self.last_token = self.current_token() - # The token that we will get updates up to - self.upto_token = self.current_token() - - def advance_current_token(self): - """Updates `upto_token` to "now", which updates up until which point - get_updates[_since] will fetch rows till. - """ - self.upto_token = self.current_token() - def discard_updates_and_advance(self): """Called when the stream should advance but the updates would be discarded, e.g. when there are no currently connected workers. """ - self.upto_token = self.current_token() - self.last_token = self.upto_token + self.last_token = self.current_token() async def get_updates(self): """Gets all updates since the last time this function was called (or - since the stream was constructed if it hadn't been called before), - until the `upto_token` + since the stream was constructed if it hadn't been called before). Returns: Deferred[Tuple[List[Tuple[int, Any]], int]: @@ -174,44 +168,45 @@ class Stream(object): return updates, current_token - async def get_updates_since(self, from_token): + async def get_updates_since( + self, from_token: int + ) -> Tuple[List[Tuple[int, JsonDict]], int]: """Like get_updates except allows specifying from when we should stream updates Returns: - Deferred[Tuple[List[Tuple[int, Any]], int]: - Resolves to a pair ``(updates, current_token)``, where ``updates`` is a - list of ``(token, row)`` entries. ``row`` will be json-serialised and - sent over the replication steam. + Resolves to a pair `(updates, new_last_token)`, where `updates` is + a list of `(token, row)` entries and `new_last_token` is the new + position in stream. """ + if from_token in ("NOW", "now"): - return [], self.upto_token + return [], self.current_token() - current_token = self.upto_token + current_token = self.current_token() from_token = int(from_token) if from_token == current_token: return [], current_token - logger.info("get_updates_since: %s", self.__class__) - if self._LIMITED: - rows = await self.update_function( - from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 - ) + rows = await self.update_function( + from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 + ) - # never turn more than MAX_EVENTS_BEHIND + 1 into updates. - rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) - else: - rows = await self.update_function(from_token, current_token) + # never turn more than MAX_EVENTS_BEHIND + 1 into updates. + rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1) updates = [(row[0], row[1:]) for row in rows] # check we didn't get more rows than the limit. # doing it like this allows the update_function to be a generator. - if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND: + if len(updates) >= MAX_EVENTS_BEHIND: raise Exception("stream %s has fallen behind" % (self.NAME)) + # The update function didn't hit the limit, so we must have got all + # the updates to `current_token`, and can return that as our new + # stream position. return updates, current_token def current_token(self): @@ -223,9 +218,8 @@ class Stream(object): """ raise NotImplementedError() - def update_function(self, from_token, current_token, limit=None): - """Get updates between from_token and to_token. If Stream._LIMITED is - True then limit is provided, otherwise it's not. + def update_function(self, from_token, current_token, limit): + """Get updates between from_token and to_token. Returns: Deferred(list(tuple)): the first entry in the tuple is the token for @@ -253,7 +247,6 @@ class BackfillStream(Stream): class PresenceStream(Stream): NAME = "presence" - _LIMITED = False ROW_TYPE = PresenceStreamRow def __init__(self, hs): @@ -268,7 +261,6 @@ class PresenceStream(Stream): class TypingStream(Stream): NAME = "typing" - _LIMITED = False ROW_TYPE = TypingStreamRow def __init__(self, hs): @@ -363,11 +355,11 @@ class PublicRoomsStream(Stream): class DeviceListsStream(Stream): - """Someone added/changed/removed a device + """Either a user has updated their devices or a remote server needs to be + told about a device update. """ NAME = "device_lists" - _LIMITED = False ROW_TYPE = DeviceListsStreamRow def __init__(self, hs): @@ -457,7 +449,6 @@ class UserSignatureStream(Stream): """ NAME = "user_signature" - _LIMITED = False ROW_TYPE = UserSignatureStreamRow def __init__(self, hs): |