summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/slave/storage/_base.py10
-rw-r--r--synapse/replication/slave/storage/devices.py13
-rw-r--r--synapse/replication/tcp/client.py20
-rw-r--r--synapse/replication/tcp/protocol.py74
-rw-r--r--synapse/replication/tcp/streams/__init__.py1
-rw-r--r--synapse/replication/tcp/streams/_base.py18
6 files changed, 125 insertions, 11 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 182cb2a1d8..456bc005a0 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Dict
 
 import six
 
@@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
 
         self.hs = hs
 
-    def stream_positions(self):
+    def stream_positions(self) -> Dict[str, int]:
+        """
+        Get the current positions of all the streams this store wants to subscribe to
+
+        Returns:
+            map from stream name to the most recent update we have for
+            that stream (ie, the point we want to start replicating from)
+        """
         pos = {}
         if self._cache_id_gen:
             pos["caches"] = self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 61557665a7..de50748c30 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,6 +15,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.data_stores.main.devices import DeviceWorkerStore
 from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
 
     def stream_positions(self):
         result = super(SlavedDeviceStore, self).stream_positions()
-        result["device_lists"] = self._device_list_id_gen.get_current_token()
+        # The user signature stream uses the same stream ID generator as the
+        # device list stream, so set them both to the device list ID
+        # generator's current token.
+        current_token = self._device_list_id_gen.get_current_token()
+        result[DeviceListsStream.NAME] = current_token
+        result[UserSignatureStream.NAME] = current_token
         return result
 
     def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "device_lists":
+        if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
             for row in rows:
                 self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+        elif stream_name == UserSignatureStream.NAME:
+            for row in rows:
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 563ce0fc53..fead78388c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,10 +16,17 @@
 """
 
 import logging
+from typing import Dict
 
 from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
 
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+    AbstractReplicationClientHandler,
+    ClientReplicationStreamProtocol,
+)
+
 from .commands import (
     FederationAckCommand,
     InvalidateCacheCommand,
@@ -27,7 +34,6 @@ from .commands import (
     UserIpCommand,
     UserSyncCommand,
 )
-from .protocol import ClientReplicationStreamProtocol
 
 logger = logging.getLogger(__name__)
 
@@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
 
     maxDelay = 30  # Try at least once every N seconds
 
-    def __init__(self, hs, client_name, handler):
+    def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
         self.client_name = client_name
         self.handler = handler
         self.server_name = hs.config.server_name
@@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
 
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
     """A base handler that can be passed to the ReplicationClientFactory.
 
     By default proxies incoming replication data to the SlaveStore.
     """
 
-    def __init__(self, store):
+    def __init__(self, store: BaseSlavedStore):
         self.store = store
 
         # The current connection. None if we are currently (re)connecting
@@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
         if d:
             d.callback(data)
 
-    def get_streams_to_replicate(self):
+    def get_streams_to_replicate(self) -> Dict[str, int]:
         """Called when a new connection has been established and we need to
         subscribe to streams.
 
-        Returns a dictionary of stream name to token.
+        Returns:
+            map from stream name to the most recent update we have for
+            that stream (ie, the point we want to start replicating from)
         """
         args = self.store.stream_positions()
         user_account_data = args.pop("user_account_data", None)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b64f3f44b5..afaf002fe6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     * connection closed by server *
 """
-
+import abc
 import fcntl
 import logging
 import struct
@@ -65,6 +65,7 @@ from twisted.python.failure import Failure
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from .commands import (
@@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.streamer.lost_connection(self)
 
 
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+    """
+    The interface for the handler that should be passed to
+    ClientReplicationStreamProtocol
+    """
+
+    @abc.abstractmethod
+    def on_rdata(self, stream_name, token, rows):
+        """Called to handle a batch of replication data with a given stream token.
+
+        Args:
+            stream_name (str): name of the replication stream for this batch of rows
+            token (int): stream token for this batch of rows
+            rows (list): a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+
+        Returns:
+            Deferred|None
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_position(self, stream_name, token):
+        """Called when we get new position data."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def on_sync(self, data):
+        """Called when get a new SYNC command."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_streams_to_replicate(self):
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+
+        Returns:
+            map from stream name to the most recent update we have for
+            that stream (ie, the point we want to start replicating from)
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_currently_syncing_users(self):
+        """Get the list of currently syncing users (if any). This is called
+        when a connection has been established and we need to send the
+        currently syncing users."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def update_connection(self, connection):
+        """Called when a connection has been established (or lost with None).
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def finished_connecting(self):
+        """Called when we have successfully subscribed and caught up to all
+        streams we're interested in.
+        """
+        raise NotImplementedError()
+
+
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
     VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
 
-    def __init__(self, client_name, server_name, clock, handler):
+    def __init__(
+        self,
+        client_name: str,
+        server_name: str,
+        clock: Clock,
+        handler: AbstractReplicationClientHandler,
+    ):
         BaseReplicationStreamProtocol.__init__(self, clock)
 
         self.client_name = client_name
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..5f52264e84 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
         _base.TagAccountDataStream,
         _base.AccountDataStream,
         _base.GroupServerStream,
+        _base.UserSignatureStream,
     )
 }
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c259..9e45429d49 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
     "GroupsStreamRow",
     ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
 )
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
 
 
 class Stream(object):
@@ -438,3 +439,20 @@ class GroupServerStream(Stream):
         self.update_function = store.get_all_groups_changes
 
         super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+    """A user has signed their own device with their user-signing key
+    """
+
+    NAME = "user_signature"
+    _LIMITED = False
+    ROW_TYPE = UserSignatureStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_device_stream_token
+        self.update_function = store.get_all_user_signature_changes_for_remotes
+
+        super(UserSignatureStream, self).__init__(hs)