diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 03560c1f0e..c8056b0c0c 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -110,14 +110,14 @@ class ReplicationEndpoint(object):
return {}
@abc.abstractmethod
- def _handle_request(self, request, **kwargs):
+ async def _handle_request(self, request, **kwargs):
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
Returns:
- Deferred[dict]: A JSON serialisable dict to be used as response
- body of request.
+ tuple[int, dict]: HTTP status code and a JSON serialisable dict
+ to be used as response body of request.
"""
pass
@@ -180,7 +180,7 @@ class ReplicationEndpoint(object):
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise
- logger.warn("%s request timed out", cls.NAME)
+ logger.warning("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2f16955954..9af4e7e173 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
return payload
- @defer.inlineCallbacks
- def _handle_request(self, request):
+ async def _handle_request(self, request):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
@@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
- context = yield EventContext.deserialize(
- self.store, event_payload["context"]
- )
+ context = EventContext.deserialize(self.store, event_payload["context"])
event_and_contexts.append((event, context))
logger.info("Got %d events from federation", len(event_and_contexts))
- yield self.federation_handler.persist_events_and_notify(
+ await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
)
@@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content}
- @defer.inlineCallbacks
- def _handle_request(self, request, edu_type):
+ async def _handle_request(self, request, edu_type):
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
@@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
logger.info("Got %r edu from %s", edu_type, origin)
- result = yield self.registry.on_edu(edu_type, origin, edu_content)
+ result = await self.registry.on_edu(edu_type, origin, edu_content)
return 200, result
@@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
"""
return {"args": args}
- @defer.inlineCallbacks
- def _handle_request(self, request, query_type):
+ async def _handle_request(self, request, query_type):
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
@@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
logger.info("Got %r query", query_type)
- result = yield self.registry.on_query(query_type, args)
+ result = await self.registry.on_query(query_type, args)
return 200, result
@@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
"""
return {}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id):
- yield self.store.clean_room_for_join(room_id)
+ async def _handle_request(self, request, room_id):
+ await self.store.clean_room_for_join(room_id)
return 200, {}
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 786f5232b2..798b9d3af5 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"is_guest": is_guest,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest
)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index b9ce3477ad..cc1f249740 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
@@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id)
- yield self.federation_handler.do_invite_join(
+ await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content
)
@@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"remote_room_hosts": remote_room_hosts,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
- event = yield self.federation_handler.do_remotely_reject_invite(
+ event = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id
)
ret = event.get_pdu_json()
@@ -148,9 +144,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
- logger.warn("Failed to reject invite: %s", e)
+ logger.warning("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(user_id, room_id)
+ await self.store.locally_reject_invite(user_id, room_id)
ret = {}
return 200, ret
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 38260256cf..915cfb9430 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -74,11 +72,10 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"address": address,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
- yield self.registration_handler.register_with_store(
+ await self.registration_handler.register_with_store(
user_id=user_id,
password_hash=content["password_hash"],
was_guest=content["was_guest"],
@@ -117,14 +114,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
"""
return {"auth_result": auth_result, "access_token": access_token}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
auth_result = content["auth_result"]
access_token = content["access_token"]
- yield self.registration_handler.post_registration_actions(
+ await self.registration_handler.post_registration_actions(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index adb9b2f7f4..9bafd60b14 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
return payload
- @defer.inlineCallbacks
- def _handle_request(self, request, event_id):
+ async def _handle_request(self, request, event_id):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
@@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
event = EventType(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"])
- context = yield EventContext.deserialize(self.store, content["context"])
+ context = EventContext.deserialize(self.store, content["context"])
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
@@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
- yield self.event_creation_handler.persist_and_notify_client_event(
+ await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 182cb2a1d8..456bc005a0 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Dict
import six
@@ -44,7 +45,14 @@ class BaseSlavedStore(SQLBaseStore):
self.hs = hs
- def stream_positions(self):
+ def stream_positions(self) -> Dict[str, int]:
+ """
+ Get the current positions of all the streams this store wants to subscribe to
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 61557665a7..de50748c30 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
- result["device_lists"] = self._device_list_id_gen.get_current_token()
+ # The user signature stream uses the same stream ID generator as the
+ # device list stream, so set them both to the device list ID
+ # generator's current token.
+ current_token = self._device_list_id_gen.get_current_token()
+ result[DeviceListsStream.NAME] = current_token
+ result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "device_lists":
+ if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ elif stream_name == UserSignatureStream.NAME:
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index a44ceb00e7..fead78388c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,10 +16,17 @@
"""
import logging
+from typing import Dict
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
FederationAckCommand,
InvalidateCacheCommand,
@@ -27,7 +34,6 @@ from .commands import (
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -42,7 +48,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -68,13 +74,13 @@ class ReplicationClientFactory(ReconnectingClientFactory):
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -138,11 +144,13 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
@@ -168,7 +176,7 @@ class ReplicationClientHandler(object):
if self.connection:
self.connection.send_command(cmd)
else:
- logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 5ffdf2675d..afaf002fe6 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,7 +48,7 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
@@ -65,6 +65,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import Clock
from synapse.util.stringutils import random_string
from .commands import (
@@ -249,7 +250,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return handler(cmd)
def close(self):
- logger.warn("[%s] Closing connection", self.id())
+ logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
self.transport.loseConnection()
self.on_connection_closed()
@@ -558,11 +559,80 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+
+ Returns:
+ Deferred|None
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..5f52264e84 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
+ _base.UserSignatureStream,
)
}
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c259..9e45429d49 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object):
@@ -438,3 +439,20 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+ """A user has signed their own device with their user-signing key
+ """
+
+ NAME = "user_signature"
+ _LIMITED = False
+ ROW_TYPE = UserSignatureStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token
+ self.update_function = store.get_all_user_signature_changes_for_remotes
+
+ super(UserSignatureStream, self).__init__(hs)
|