summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/slave/storage/devices.py36
-rw-r--r--synapse/replication/tcp/resource.py9
-rw-r--r--synapse/replication/tcp/streams/__init__.py70
-rw-r--r--synapse/replication/tcp/streams/_base.py246
-rw-r--r--synapse/replication/tcp/streams/federation.py16
5 files changed, 206 insertions, 171 deletions
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
-            for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+            self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
 
-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
-
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-        self.get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..6e2ebaf614 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -166,11 +166,6 @@ class ReplicationStreamer(object):
                 self.pending_updates = False
 
                 with Measure(self.clock, "repl.stream.get_updates"):
-                    # First we tell the streams that they should update their
-                    # current tokens.
-                    for stream in self.streams:
-                        stream.advance_current_token()
-
                     all_streams = self.streams
 
                     if self._replication_torture_level is not None:
@@ -180,7 +175,7 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.upto_token:
+                        if stream.last_token == stream.current_token():
                             continue
 
                         if self._replication_torture_level:
@@ -192,7 +187,7 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.upto_token,
+                            stream.current_token(),
                         )
                         try:
                             updates, current_token = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5f52264e84..29199f5b46 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,27 +24,61 @@ Each stream is defined by the following information:
     current_token:      The function that returns the current token for the stream
     update_function:    The function that returns a list of updates between two tokens
 """
-
-from . import _base, events, federation
+from synapse.replication.tcp.streams._base import (
+    AccountDataStream,
+    BackfillStream,
+    CachesStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PublicRoomsStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+    UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
 
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
-        events.EventsStream,
-        _base.BackfillStream,
-        _base.PresenceStream,
-        _base.TypingStream,
-        _base.ReceiptsStream,
-        _base.PushRulesStream,
-        _base.PushersStream,
-        _base.CachesStream,
-        _base.PublicRoomsStream,
-        _base.DeviceListsStream,
-        _base.ToDeviceStream,
-        federation.FederationStream,
-        _base.TagAccountDataStream,
-        _base.AccountDataStream,
-        _base.GroupServerStream,
-        _base.UserSignatureStream,
+        EventsStream,
+        BackfillStream,
+        PresenceStream,
+        TypingStream,
+        ReceiptsStream,
+        PushRulesStream,
+        PushersStream,
+        CachesStream,
+        PublicRoomsStream,
+        DeviceListsStream,
+        ToDeviceStream,
+        FederationStream,
+        TagAccountDataStream,
+        AccountDataStream,
+        GroupServerStream,
+        UserSignatureStream,
     )
 }
+
+__all__ = [
+    "STREAMS_MAP",
+    "BackfillStream",
+    "PresenceStream",
+    "TypingStream",
+    "ReceiptsStream",
+    "PushRulesStream",
+    "PushersStream",
+    "CachesStream",
+    "PublicRoomsStream",
+    "DeviceListsStream",
+    "ToDeviceStream",
+    "TagAccountDataStream",
+    "AccountDataStream",
+    "GroupServerStream",
+    "UserSignatureStream",
+]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..32d9514883 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -17,111 +17,28 @@
 import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
 
 import attr
 
+from synapse.types import JsonDict
+
 logger = logging.getLogger(__name__)
 
 
 MAX_EVENTS_BEHIND = 500000
 
-BackfillStreamRow = namedtuple(
-    "BackfillStreamRow",
-    (
-        "event_id",  # str
-        "room_id",  # str
-        "type",  # str
-        "state_key",  # str, optional
-        "redacts",  # str, optional
-        "relates_to",  # str, optional
-    ),
-)
-PresenceStreamRow = namedtuple(
-    "PresenceStreamRow",
-    (
-        "user_id",  # str
-        "state",  # str
-        "last_active_ts",  # int
-        "last_federation_update_ts",  # int
-        "last_user_sync_ts",  # int
-        "status_msg",  # str
-        "currently_active",  # bool
-    ),
-)
-TypingStreamRow = namedtuple(
-    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
-)
-ReceiptsStreamRow = namedtuple(
-    "ReceiptsStreamRow",
-    (
-        "room_id",  # str
-        "receipt_type",  # str
-        "user_id",  # str
-        "event_id",  # str
-        "data",  # dict
-    ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
-PushersStreamRow = namedtuple(
-    "PushersStreamRow",
-    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
-)
-
-
-@attr.s
-class CachesStreamRow:
-    """Stream to inform workers they should invalidate their cache.
-
-    Attributes:
-        cache_func: Name of the cached function.
-        keys: The entry in the cache to invalidate. If None then will
-            invalidate all.
-        invalidation_ts: Timestamp of when the invalidation took place.
-    """
-
-    cache_func = attr.ib(type=str)
-    keys = attr.ib(type=Optional[List[Any]])
-    invalidation_ts = attr.ib(type=int)
-
-
-PublicRoomsStreamRow = namedtuple(
-    "PublicRoomsStreamRow",
-    (
-        "room_id",  # str
-        "visibility",  # str
-        "appservice_id",  # str, optional
-        "network_id",  # str, optional
-    ),
-)
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
-TagAccountDataStreamRow = namedtuple(
-    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
-)
-AccountDataStreamRow = namedtuple(
-    "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
-)
-GroupsStreamRow = namedtuple(
-    "GroupsStreamRow",
-    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
-)
-UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
-
 
 class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
     NAME = None  # type: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
-    _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
     def parse_row(cls, row):
@@ -142,26 +59,15 @@ class Stream(object):
         # The token from which we last asked for updates
         self.last_token = self.current_token()
 
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
-
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
-        """
-        self.upto_token = self.current_token()
-
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token()
 
     async def get_updates(self):
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
             Deferred[Tuple[List[Tuple[int, Any]], int]:
@@ -174,44 +80,45 @@ class Stream(object):
 
         return updates, current_token
 
-    async def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, from_token: int
+    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            Resolves to a pair `(updates, new_last_token)`, where `updates` is
+            a list of `(token, row)` entries and `new_last_token` is the new
+            position in stream.
         """
+
         if from_token in ("NOW", "now"):
-            return [], self.upto_token
+            return [], self.current_token()
 
-        current_token = self.upto_token
+        current_token = self.current_token()
 
         from_token = int(from_token)
 
         if from_token == current_token:
             return [], current_token
 
-        logger.info("get_updates_since: %s", self.__class__)
-        if self._LIMITED:
-            rows = await self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        rows = await self.update_function(
+            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+        )
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = await self.update_function(from_token, current_token)
+        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
+        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
 
         updates = [(row[0], row[1:]) for row in rows]
 
         # check we didn't get more rows than the limit.
         # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
+        if len(updates) >= MAX_EVENTS_BEHIND:
             raise Exception("stream %s has fallen behind" % (self.NAME))
 
+        # The update function didn't hit the limit, so we must have got all
+        # the updates to `current_token`, and can return that as our new
+        # stream position.
         return updates, current_token
 
     def current_token(self):
@@ -223,9 +130,8 @@ class Stream(object):
         """
         raise NotImplementedError()
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    def update_function(self, from_token, current_token, limit):
+        """Get updates between from_token and to_token.
 
         Returns:
             Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -240,6 +146,18 @@ class BackfillStream(Stream):
     or it went from being an outlier to not.
     """
 
+    BackfillStreamRow = namedtuple(
+        "BackfillStreamRow",
+        (
+            "event_id",  # str
+            "room_id",  # str
+            "type",  # str
+            "state_key",  # str, optional
+            "redacts",  # str, optional
+            "relates_to",  # str, optional
+        ),
+    )
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
@@ -252,8 +170,20 @@ class BackfillStream(Stream):
 
 
 class PresenceStream(Stream):
+    PresenceStreamRow = namedtuple(
+        "PresenceStreamRow",
+        (
+            "user_id",  # str
+            "state",  # str
+            "last_active_ts",  # int
+            "last_federation_update_ts",  # int
+            "last_user_sync_ts",  # int
+            "status_msg",  # str
+            "currently_active",  # bool
+        ),
+    )
+
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
@@ -267,8 +197,11 @@ class PresenceStream(Stream):
 
 
 class TypingStream(Stream):
+    TypingStreamRow = namedtuple(
+        "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+    )
+
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
@@ -281,6 +214,17 @@ class TypingStream(Stream):
 
 
 class ReceiptsStream(Stream):
+    ReceiptsStreamRow = namedtuple(
+        "ReceiptsStreamRow",
+        (
+            "room_id",  # str
+            "receipt_type",  # str
+            "user_id",  # str
+            "event_id",  # str
+            "data",  # dict
+        ),
+    )
+
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
 
@@ -297,6 +241,8 @@ class PushRulesStream(Stream):
     """A user has changed their push rules
     """
 
+    PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
@@ -317,6 +263,11 @@ class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
 
+    PushersStreamRow = namedtuple(
+        "PushersStreamRow",
+        ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+    )
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
@@ -334,6 +285,21 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
+    @attr.s
+    class CachesStreamRow:
+        """Stream to inform workers they should invalidate their cache.
+
+        Attributes:
+            cache_func: Name of the cached function.
+            keys: The entry in the cache to invalidate. If None then will
+                invalidate all.
+            invalidation_ts: Timestamp of when the invalidation took place.
+        """
+
+        cache_func = attr.ib(type=str)
+        keys = attr.ib(type=Optional[List[Any]])
+        invalidation_ts = attr.ib(type=int)
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
@@ -350,6 +316,16 @@ class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
 
+    PublicRoomsStreamRow = namedtuple(
+        "PublicRoomsStreamRow",
+        (
+            "room_id",  # str
+            "visibility",  # str
+            "appservice_id",  # str, optional
+            "network_id",  # str, optional
+        ),
+    )
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
@@ -363,11 +339,15 @@ class PublicRoomsStream(Stream):
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
+    @attr.s
+    class DeviceListsStreamRow:
+        entity = attr.ib(type=str)
+
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
@@ -383,6 +363,8 @@ class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
 
+    ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
@@ -399,6 +381,10 @@ class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
 
+    TagAccountDataStreamRow = namedtuple(
+        "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+    )
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
@@ -415,6 +401,10 @@ class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
 
+    AccountDataStreamRow = namedtuple(
+        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+    )
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
@@ -440,6 +430,11 @@ class AccountDataStream(Stream):
 
 
 class GroupServerStream(Stream):
+    GroupsStreamRow = namedtuple(
+        "GroupsStreamRow",
+        ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+    )
+
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
 
@@ -456,8 +451,9 @@ class UserSignatureStream(Stream):
     """A user has signed their own device with their user-signing key
     """
 
+    UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+
     NAME = "user_signature"
-    _LIMITED = False
     ROW_TYPE = UserSignatureStreamRow
 
     def __init__(self, hs):
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 615f3dc9ac..f5f9336430 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,20 +17,20 @@ from collections import namedtuple
 
 from ._base import Stream
 
-FederationStreamRow = namedtuple(
-    "FederationStreamRow",
-    (
-        "type",  # str, the type of data as defined in the BaseFederationRows
-        "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-    ),
-)
-
 
 class FederationStream(Stream):
     """Data to be sent over federation. Only available when master has federation
     sending disabled.
     """
 
+    FederationStreamRow = namedtuple(
+        "FederationStreamRow",
+        (
+            "type",  # str, the type of data as defined in the BaseFederationRows
+            "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+        ),
+    )
+
     NAME = "federation"
     ROW_TYPE = FederationStreamRow