diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 81b85352b1..28dbc6fcba 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,14 @@
# limitations under the License.
from synapse.http.server import JsonResource
-from synapse.replication.http import federation, login, membership, register, send_event
+from synapse.replication.http import (
+ devices,
+ federation,
+ login,
+ membership,
+ register,
+ send_event,
+)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -30,3 +37,4 @@ class ReplicationRestResource(JsonResource):
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
+ devices.register_servlets(hs, self)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
new file mode 100644
index 0000000000..e32aac0a25
--- /dev/null
+++ b/synapse/replication/http/devices.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
+ """Ask master to resync the device list for a user by contacting their
+ server.
+
+ This must happen on master so that the results can be correctly cached in
+ the database and streamed to workers.
+
+ Request format:
+
+ POST /_synapse/replication/user_device_resync/:user_id
+
+ {}
+
+ Response is equivalent to `/_matrix/federation/v1/user/devices/:user_id`
+ response, e.g.:
+
+ {
+ "user_id": "@alice:example.org",
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": { ... },
+ "device_display_name": "Alice's Mobile Phone"
+ }
+ ]
+ }
+ """
+
+ NAME = "user_device_resync"
+ PATH_ARGS = ("user_id",)
+ CACHE = False
+
+ def __init__(self, hs):
+ super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+
+ self.device_list_updater = hs.get_device_handler().device_list_updater
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ def _serialize_payload(user_id):
+ return {}
+
+ async def _handle_request(self, request, user_id):
+ user_devices = await self.device_list_updater.user_device_resync(user_id)
+
+ return 200, user_devices
+
+
+def register_servlets(hs, http_server):
+ ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
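For context on how this endpoint is consumed: `ReplicationEndpoint` subclasses expose a `make_client(hs)` classmethod that builds a callable which POSTs the serialised payload to the master and returns its JSON response. Below is a minimal sketch of worker-side usage under that assumption; the helper class and method names are illustrative and not part of this change.

```python
# Sketch of worker-side usage, assuming the standard ReplicationEndpoint
# pattern: make_client() returns a callable that POSTs to
# /_synapse/replication/user_device_resync/<user_id> on the master and
# returns the JSON response. The helper class below is illustrative.
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet


class WorkerDeviceResyncer(object):  # hypothetical worker-side helper
    def __init__(self, hs):
        # Callable taking the path args / payload as keyword arguments.
        self._user_device_resync = (
            ReplicationUserDevicesResyncRestServlet.make_client(hs)
        )

    async def resync(self, user_id):
        # The master contacts the remote homeserver, caches the result in
        # the database, and hands the federation response back verbatim.
        result = await self._user_device_resync(user_id=user_id)
        return result.get("devices", [])
```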
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 915cfb9430..0c4aca1291 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -75,6 +75,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
+ self.registration_handler.check_registration_ratelimit(content["address"])
+
await self.registration_handler.register_with_store(
user_id=user_id,
password_hash=content["password_hash"],
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 182cb2a1d8..456bc005a0 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Dict
import six
@@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
self.hs = hs
- def stream_positions(self):
+ def stream_positions(self) -> Dict[str, int]:
+ """
+ Get the current positions of all the streams this store wants to subscribe to
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 563ce0fc53..fead78388c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,10 +16,17 @@
"""
import logging
+from typing import Dict
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
FederationAckCommand,
InvalidateCacheCommand,
@@ -27,7 +34,6 @@ from .commands import (
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
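The dictionary returned by `get_streams_to_replicate()` drives what happens on (re)connection: the client protocol asks the server to resume each named stream from the given token. A rough sketch of that flow follows, assuming the usual one-command-per-stream behaviour; everything except the handler methods shown is illustrative.

```python
# Rough sketch of what get_streams_to_replicate() feeds into; everything
# except the two handler calls shown is illustrative.
from synapse.replication.tcp.client import ReplicationClientHandler


def show_initial_subscriptions(store):
    # `store` is assumed to be a BaseSlavedStore (or subclass).
    handler = ReplicationClientHandler(store)

    # e.g. {"events": 1234, "caches": 99}: one entry per stream, keyed on
    # stream name, valued with the last token that store has processed.
    streams = handler.get_streams_to_replicate()

    for stream_name, token in streams.items():
        # On connection the protocol issues the equivalent of
        # "REPLICATE <stream_name> <token>" for each entry, so streaming
        # resumes from the point each store left off.
        print(stream_name, token)
```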
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b64f3f44b5..afaf002fe6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
@@ -65,6 +65,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from .commands import (
@@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
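The new `AbstractReplicationClientHandler` pins down the interface that `ClientReplicationStreamProtocol` expects and that `ReplicationClientHandler` now implements. The following minimal (do-nothing) concrete subclass is a sketch to make the required surface explicit; the class name and behaviour are illustrative only.

```python
# Minimal sketch of a concrete AbstractReplicationClientHandler, to make
# the required surface explicit. The do-nothing behaviour and the class
# name are illustrative only.
from synapse.replication.tcp.protocol import AbstractReplicationClientHandler


class NoopReplicationHandler(AbstractReplicationClientHandler):
    def on_rdata(self, stream_name, token, rows):
        pass  # a real handler would apply the rows to its stores/caches

    def on_position(self, stream_name, token):
        pass

    def on_sync(self, data):
        pass

    def get_streams_to_replicate(self):
        return {}  # subscribe to no streams

    def get_currently_syncing_users(self):
        return []

    def update_connection(self, connection):
        self.connection = connection

    def finished_connecting(self):
        pass
```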
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9e45429d49..8512923eae 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -88,8 +88,7 @@ TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
AccountDataStreamRow = namedtuple(
- "AccountDataStream",
- ("user_id", "room_id", "data_type", "data"), # str # str # str # dict
+ "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
@@ -421,8 +420,8 @@ class AccountDataStream(Stream):
results = list(room_results)
results.extend(
- (stream_id, user_id, None, account_data_type, content)
- for stream_id, user_id, account_data_type, content in global_results
+ (stream_id, user_id, None, account_data_type)
+ for stream_id, user_id, account_data_type in global_results
)
return results
|