summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/slave/storage/devices.py36
-rw-r--r--synapse/replication/tcp/resource.py9
-rw-r--r--synapse/replication/tcp/streams/__init__.py70
-rw-r--r--synapse/replication/tcp/streams/_base.py79
4 files changed, 112 insertions, 82 deletions
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
-            for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+            self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
 
-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
-
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-        self.get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..6e2ebaf614 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -166,11 +166,6 @@ class ReplicationStreamer(object):
                 self.pending_updates = False
 
                 with Measure(self.clock, "repl.stream.get_updates"):
-                    # First we tell the streams that they should update their
-                    # current tokens.
-                    for stream in self.streams:
-                        stream.advance_current_token()
-
                     all_streams = self.streams
 
                     if self._replication_torture_level is not None:
@@ -180,7 +175,7 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.upto_token:
+                        if stream.last_token == stream.current_token():
                             continue
 
                         if self._replication_torture_level:
@@ -192,7 +187,7 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.upto_token,
+                            stream.current_token(),
                         )
                         try:
                             updates, current_token = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5f52264e84..29199f5b46 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,27 +24,61 @@ Each stream is defined by the following information:
     current_token:      The function that returns the current token for the stream
     update_function:    The function that returns a list of updates between two tokens
 """
-
-from . import _base, events, federation
+from synapse.replication.tcp.streams._base import (
+    AccountDataStream,
+    BackfillStream,
+    CachesStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PublicRoomsStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+    UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
 
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
-        events.EventsStream,
-        _base.BackfillStream,
-        _base.PresenceStream,
-        _base.TypingStream,
-        _base.ReceiptsStream,
-        _base.PushRulesStream,
-        _base.PushersStream,
-        _base.CachesStream,
-        _base.PublicRoomsStream,
-        _base.DeviceListsStream,
-        _base.ToDeviceStream,
-        federation.FederationStream,
-        _base.TagAccountDataStream,
-        _base.AccountDataStream,
-        _base.GroupServerStream,
-        _base.UserSignatureStream,
+        EventsStream,
+        BackfillStream,
+        PresenceStream,
+        TypingStream,
+        ReceiptsStream,
+        PushRulesStream,
+        PushersStream,
+        CachesStream,
+        PublicRoomsStream,
+        DeviceListsStream,
+        ToDeviceStream,
+        FederationStream,
+        TagAccountDataStream,
+        AccountDataStream,
+        GroupServerStream,
+        UserSignatureStream,
     )
 }
+
+__all__ = [
+    "STREAMS_MAP",
+    "BackfillStream",
+    "PresenceStream",
+    "TypingStream",
+    "ReceiptsStream",
+    "PushRulesStream",
+    "PushersStream",
+    "CachesStream",
+    "PublicRoomsStream",
+    "DeviceListsStream",
+    "ToDeviceStream",
+    "TagAccountDataStream",
+    "AccountDataStream",
+    "GroupServerStream",
+    "UserSignatureStream",
+]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..abf5c6c6a8 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -17,10 +17,12 @@
 import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
 
 import attr
 
+from synapse.types import JsonDict
+
 logger = logging.getLogger(__name__)
 
 
@@ -94,9 +96,13 @@ PublicRoomsStreamRow = namedtuple(
         "network_id",  # str, optional
     ),
 )
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
+
+
+@attr.s
+class DeviceListsStreamRow:
+    entity = attr.ib(type=str)
+
+
 ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
 TagAccountDataStreamRow = namedtuple(
     "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
@@ -115,13 +121,12 @@ class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
     NAME = None  # type: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
-    _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
     def parse_row(cls, row):
@@ -142,26 +147,15 @@ class Stream(object):
         # The token from which we last asked for updates
         self.last_token = self.current_token()
 
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
-
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
-        """
-        self.upto_token = self.current_token()
-
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token()
 
     async def get_updates(self):
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
             Deferred[Tuple[List[Tuple[int, Any]], int]:
@@ -174,44 +168,45 @@ class Stream(object):
 
         return updates, current_token
 
-    async def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, from_token: int
+    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            Resolves to a pair `(updates, new_last_token)`, where `updates` is
+            a list of `(token, row)` entries and `new_last_token` is the new
+            position in stream.
         """
+
         if from_token in ("NOW", "now"):
-            return [], self.upto_token
+            return [], self.current_token()
 
-        current_token = self.upto_token
+        current_token = self.current_token()
 
         from_token = int(from_token)
 
         if from_token == current_token:
             return [], current_token
 
-        logger.info("get_updates_since: %s", self.__class__)
-        if self._LIMITED:
-            rows = await self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        rows = await self.update_function(
+            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+        )
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = await self.update_function(from_token, current_token)
+        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
+        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
 
         updates = [(row[0], row[1:]) for row in rows]
 
         # check we didn't get more rows than the limit.
         # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
+        if len(updates) >= MAX_EVENTS_BEHIND:
             raise Exception("stream %s has fallen behind" % (self.NAME))
 
+        # The update function didn't hit the limit, so we must have got all
+        # the updates to `current_token`, and can return that as our new
+        # stream position.
         return updates, current_token
 
     def current_token(self):
@@ -223,9 +218,8 @@ class Stream(object):
         """
         raise NotImplementedError()
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    def update_function(self, from_token, current_token, limit):
+        """Get updates between from_token and to_token.
 
         Returns:
             Deferred(list(tuple)): the first entry in the tuple is the token for
@@ -253,7 +247,6 @@ class BackfillStream(Stream):
 
 class PresenceStream(Stream):
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
@@ -268,7 +261,6 @@ class PresenceStream(Stream):
 
 class TypingStream(Stream):
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
@@ -363,11 +355,11 @@ class PublicRoomsStream(Stream):
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
@@ -457,7 +449,6 @@ class UserSignatureStream(Stream):
     """
 
     NAME = "user_signature"
-    _LIMITED = False
     ROW_TYPE = UserSignatureStreamRow
 
     def __init__(self, hs):