summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
authorBen Banfield-Zanin <benbz@matrix.org>2020-09-15 11:44:49 +0100
committerBen Banfield-Zanin <benbz@matrix.org>2020-09-15 11:44:49 +0100
commit1a7d96aa6ff81638f2ea696fdee2ec44e7bff75a (patch)
tree1839e80f89c53b34ff1b36974305c6cb0c94aab4 /synapse/replication
parentFix group server for older synapse (diff)
parentClarify changelog. (diff)
downloadsynapse-bbz/info-mainline-1.20.0.tar.xz
Merge remote-tracking branch 'origin/release-v1.20.0' into bbz/info-mainline-1.20.0 github/bbz/info-mainline-1.20.0 bbz/info-mainline-1.20.0
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/__init__.py5
-rw-r--r--synapse/replication/http/_base.py41
-rw-r--r--synapse/replication/http/devices.py2
-rw-r--r--synapse/replication/http/federation.py15
-rw-r--r--synapse/replication/http/login.py2
-rw-r--r--synapse/replication/http/membership.py98
-rw-r--r--synapse/replication/http/presence.py4
-rw-r--r--synapse/replication/http/register.py8
-rw-r--r--synapse/replication/http/send_event.py7
-rw-r--r--synapse/replication/http/streams.py2
-rw-r--r--synapse/replication/slave/storage/_base.py6
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py14
-rw-r--r--synapse/replication/slave/storage/account_data.py17
-rw-r--r--synapse/replication/slave/storage/appservice.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py8
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py13
-rw-r--r--synapse/replication/slave/storage/devices.py15
-rw-r--r--synapse/replication/slave/storage/directory.py2
-rw-r--r--synapse/replication/slave/storage/events.py24
-rw-r--r--synapse/replication/slave/storage/filtering.py6
-rw-r--r--synapse/replication/slave/storage/groups.py11
-rw-r--r--synapse/replication/slave/storage/keys.py2
-rw-r--r--synapse/replication/slave/storage/presence.py11
-rw-r--r--synapse/replication/slave/storage/profile.py2
-rw-r--r--synapse/replication/slave/storage/push_rule.py17
-rw-r--r--synapse/replication/slave/storage/pushers.py11
-rw-r--r--synapse/replication/slave/storage/receipts.py19
-rw-r--r--synapse/replication/slave/storage/registration.py2
-rw-r--r--synapse/replication/slave/storage/room.py11
-rw-r--r--synapse/replication/slave/storage/transactions.py2
-rw-r--r--synapse/replication/tcp/__init__.py2
-rw-r--r--synapse/replication/tcp/client.py16
-rw-r--r--synapse/replication/tcp/commands.py34
-rw-r--r--synapse/replication/tcp/handler.py413
-rw-r--r--synapse/replication/tcp/protocol.py62
-rw-r--r--synapse/replication/tcp/redis.py61
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py96
-rw-r--r--synapse/replication/tcp/streams/events.py10
39 files changed, 584 insertions, 491 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py

index 19b69e0e11..a84a064c8d 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication" class ReplicationRestResource(JsonResource): def __init__(self, hs): - JsonResource.__init__(self, hs, canonical_json=False) + # We enable extracting jaeger contexts here as these are internal APIs. + super().__init__(hs, canonical_json=False, extract_context=True) self.register_servlets(hs) def register_servlets(self, hs): @@ -38,10 +39,10 @@ class ReplicationRestResource(JsonResource): federation.register_servlets(hs, self) presence.register_servlets(hs, self) membership.register_servlets(hs, self) + streams.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: login.register_servlets(hs, self) register.register_servlets(hs, self) devices.register_servlets(hs, self) - streams.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 793cef6c26..ba16f22c91 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py
@@ -16,32 +16,24 @@ import abc import logging import re +import urllib from inspect import signature from typing import Dict, List, Tuple -from six import raise_from -from six.moves import urllib - -from twisted.internet import defer - from synapse.api.errors import ( CodeMessageException, HttpResponseException, RequestSendFailed, SynapseError, ) -from synapse.logging.opentracing import ( - inject_active_span_byte_dict, - trace, - trace_servlet, -) +from synapse.logging.opentracing import inject_active_span_byte_dict, trace from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) -class ReplicationEndpoint(object): +class ReplicationEndpoint: """Helper base class for defining new replication HTTP endpoints. This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..` @@ -98,16 +90,16 @@ class ReplicationEndpoint(object): # assert here that sub classes don't try and use the name. assert ( "instance_name" not in self.PATH_ARGS - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert ( "instance_name" not in signature(self.__class__._serialize_payload).parameters - ), "`instance_name` is a reserved paramater name" + ), "`instance_name` is a reserved parameter name" assert self.METHOD in ("PUT", "POST", "GET") @abc.abstractmethod - def _serialize_payload(**kwargs): + async def _serialize_payload(**kwargs): """Static method that is called when creating a request. Concrete implementations should have explicit parameters (rather than @@ -116,9 +108,8 @@ class ReplicationEndpoint(object): argument list. Returns: - Deferred[dict]|dict: If POST/PUT request then dictionary must be - JSON serialisable, otherwise must be appropriate for adding as - query args. + dict: If POST/PUT request then dictionary must be JSON serialisable, + otherwise must be appropriate for adding as query args. """ return {} @@ -150,8 +141,7 @@ class ReplicationEndpoint(object): instance_map = hs.config.worker.instance_map @trace(opname="outgoing_replication_request") - @defer.inlineCallbacks - def send_request(instance_name="master", **kwargs): + async def send_request(instance_name="master", **kwargs): if instance_name == local_instance_name: raise Exception("Trying to send HTTP request to self") if instance_name == "master": @@ -165,7 +155,7 @@ class ReplicationEndpoint(object): "Instance %r not in 'instance_map' config" % (instance_name,) ) - data = yield cls._serialize_payload(**kwargs) + data = await cls._serialize_payload(**kwargs) url_args = [ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS @@ -203,7 +193,7 @@ class ReplicationEndpoint(object): headers = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers, None, check_destination=False) try: - result = yield request_func(uri, data, headers=headers) + result = await request_func(uri, data, headers=headers) break except CodeMessageException as e: if e.code != 504 or not cls.RETRY_ON_TIMEOUT: @@ -213,14 +203,14 @@ class ReplicationEndpoint(object): # If we timed out we probably don't need to worry about backing # off too much, but lets just wait a little anyway. - yield clock.sleep(1) + await clock.sleep(1) except HttpResponseException as e: # We convert to SynapseError as we know that it was a SynapseError # on the master process that we should send to the client. (And # importantly, not stack traces everywhere) raise e.to_synapse_error() except RequestSendFailed as e: - raise_from(SynapseError(502, "Failed to talk to master"), e) + raise SynapseError(502, "Failed to talk to master") from e return result @@ -242,11 +232,8 @@ class ReplicationEndpoint(object): args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) - handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler) - # We don't let register paths trace this servlet using the default tracing - # options because we wish to extract the context explicitly. http_server.register_paths( - method, [pattern], handler, self.__class__.__name__, trace=False + method, [pattern], handler, self.__class__.__name__, ) def _cached_handler(self, request, txn_id, **kwargs): diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index e32aac0a25..20f3ba76c0 100644 --- a/synapse/replication/http/devices.py +++ b/synapse/replication/http/devices.py
@@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(user_id): + async def _serialize_payload(user_id): return {} async def _handle_request(self, request, user_id): diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index c287c4e269..6b56315148 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py
@@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext @@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): self.federation_handler = hs.get_handlers().federation_handler @staticmethod - @defer.inlineCallbacks - def _serialize_payload(store, event_and_contexts, backfilled): + async def _serialize_payload(store, event_and_contexts, backfilled): """ Args: store @@ -78,7 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """ event_payloads = [] for event, context in event_and_contexts: - serialized_context = yield context.serialize(event, store) + serialized_context = await context.serialize(event, store) event_payloads.append( { @@ -154,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - def _serialize_payload(edu_type, origin, content): + async def _serialize_payload(edu_type, origin, content): return {"origin": origin, "content": content} async def _handle_request(self, request, edu_type): @@ -197,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): self.registry = hs.get_federation_registry() @staticmethod - def _serialize_payload(query_type, args): + async def _serialize_payload(query_type, args): """ Args: query_type (str) @@ -238,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - def _serialize_payload(room_id, args): + async def _serialize_payload(room_id, args): """ Args: room_id (str) @@ -273,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint): self.store = hs.get_datastore() @staticmethod - def _serialize_payload(room_id, room_version): + async def _serialize_payload(room_id, room_version): return {"room_version": room_version.identifier} async def _handle_request(self, request, room_id): diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 798b9d3af5..fb326bb869 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py
@@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, device_id, initial_display_name, is_guest): + async def _serialize_payload(user_id, device_id, initial_display_name, is_guest): """ Args: device_id (str|None): Device ID to use, if None a new one is diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index a7174c4a8f..741329ab5f 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py
@@ -14,11 +14,11 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import Requester, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.util.distributor import user_joined_room, user_left_room if TYPE_CHECKING: @@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): + async def _serialize_payload( + requester, room_id, user_id, remote_room_hosts, content + ): """ Args: requester(Requester) @@ -88,49 +90,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): - """Rejects the invite for the user and room. + """Rejects an out-of-band invite we have received from a remote server Request format: - POST /_synapse/replication/remote_reject_invite/:room_id/:user_id + POST /_synapse/replication/remote_reject_invite/:event_id { + "txn_id": ..., "requester": ..., - "remote_room_hosts": [...], "content": { ... } } """ NAME = "remote_reject_invite" - PATH_ARGS = ("room_id", "user_id") + PATH_ARGS = ("invite_event_id",) - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs) - self.federation_handler = hs.get_handlers().federation_handler self.store = hs.get_datastore() self.clock = hs.get_clock() self.member_handler = hs.get_room_member_handler() @staticmethod - def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content): + async def _serialize_payload( # type: ignore + invite_event_id: str, + txn_id: Optional[str], + requester: Requester, + content: JsonDict, + ): """ Args: - requester(Requester) - room_id (str) - user_id (str) - remote_room_hosts (list[str]): Servers to try and reject via + invite_event_id: ID of the invite to be rejected + txn_id: optional transaction ID supplied by the client + requester: user making the rejection request, according to the access token + content: additional content to include in the rejection event. + Normally an empty dict. """ return { + "txn_id": txn_id, "requester": requester.serialize(), - "remote_room_hosts": remote_room_hosts, "content": content, } - async def _handle_request(self, request, room_id, user_id): + async def _handle_request(self, request, invite_event_id): content = parse_json_object_from_request(request) - remote_room_hosts = content["remote_room_hosts"] + txn_id = content["txn_id"] event_content = content["content"] requester = Requester.deserialize(self.store, content["requester"]) @@ -138,60 +145,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): if requester.user: request.authenticated_entity = requester.user.to_string() - logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) - - try: - event, stream_id = await self.federation_handler.do_remotely_reject_invite( - remote_room_hosts, room_id, user_id, event_content, - ) - event_id = event.event_id - except Exception as e: - # if we were unable to reject the exception, just mark - # it as rejected on our end and plough ahead. - # - # The 'except' clause is very broad, but we need to - # capture everything from DNS failures upwards - # - logger.warning("Failed to reject invite: %s", e) - - stream_id = await self.member_handler.locally_reject_invite( - user_id, room_id - ) - event_id = None + # hopefully we're now on the master, so this won't recurse! + event_id, stream_id = await self.member_handler.remote_reject_invite( + invite_event_id, txn_id, requester, event_content, + ) return 200, {"event_id": event_id, "stream_id": stream_id} -class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint): - """Rejects the invite for the user and room locally. - - Request format: - - POST /_synapse/replication/locally_reject_invite/:room_id/:user_id - - {} - """ - - NAME = "locally_reject_invite" - PATH_ARGS = ("room_id", "user_id") - - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - self.member_handler = hs.get_room_member_handler() - - @staticmethod - def _serialize_payload(room_id, user_id): - return {} - - async def _handle_request(self, request, room_id, user_id): - logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id) - - stream_id = await self.member_handler.locally_reject_invite(user_id, room_id) - - return 200, {"stream_id": stream_id} - - class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): """Notifies that a user has joined or left the room @@ -215,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): self.distributor = hs.get_distributor() @staticmethod - def _serialize_payload(room_id, user_id, change): + async def _serialize_payload(room_id, user_id, change): """ Args: room_id (str) @@ -245,4 +206,3 @@ def register_servlets(hs, http_server): ReplicationRemoteJoinRestServlet(hs).register(http_server) ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) - ReplicationLocallyRejectInviteRestServlet(hs).register(http_server) diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index ea1b33331b..bc9aa82cb4 100644 --- a/synapse/replication/http/presence.py +++ b/synapse/replication/http/presence.py
@@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - def _serialize_payload(user_id): + async def _serialize_payload(user_id): return {} async def _handle_request(self, request, user_id): @@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint): self._presence_handler = hs.get_presence_handler() @staticmethod - def _serialize_payload(user_id, state, ignore_status_msg=False): + async def _serialize_payload(user_id, state, ignore_status_msg=False): return { "state": state, "ignore_status_msg": ignore_status_msg, diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 0c4aca1291..a02b27474d 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py
@@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload( + async def _serialize_payload( user_id, password_hash, was_guest, @@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin, user_type, address, + shadow_banned, ): """ Args: @@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): user_type (str|None): type of user. One of the values from api.constants.UserTypes, or None for a normal user. address (str|None): the IP address used to perform the regitration. + shadow_banned (bool): Whether to shadow-ban the user """ return { "password_hash": password_hash, @@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): "admin": admin, "user_type": user_type, "address": address, + "shadow_banned": shadow_banned, } async def _handle_request(self, request, user_id): @@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): admin=content["admin"], user_type=content["user_type"], address=content["address"], + shadow_banned=content["shadow_banned"], ) return 200, {} @@ -105,7 +109,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): self.registration_handler = hs.get_registration_handler() @staticmethod - def _serialize_payload(user_id, auth_result, access_token): + async def _serialize_payload(user_id, auth_result, access_token): """ Args: user_id (str): The user ID that consented diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index c981723c1a..f13d452426 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py
@@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext @@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.clock = hs.get_clock() @staticmethod - @defer.inlineCallbacks - def _serialize_payload( + async def _serialize_payload( event_id, store, event, context, requester, ratelimit, extra_users ): """ @@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): extra_users (list(UserID)): Any extra users to notify about event """ - serialized_context = yield context.serialize(event, store) + serialized_context = await context.serialize(event, store) payload = { "event": event.get_pdu_json(), diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index bde97eef32..309159e304 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py
@@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): self.streams = hs.get_replication_streams() @staticmethod - def _serialize_payload(stream_name, from_token, upto_token): + async def _serialize_payload(stream_name, from_token, upto_token): return {"from_token": from_token, "upto_token": upto_token} async def _handle_request(self, request, stream_name): diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f9e2533e96..60f2e1245f 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py
@@ -16,8 +16,8 @@ import logging from typing import Optional -from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) class BaseSlavedStore(CacheInvalidationWorkerStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = MultiWriterIdGenerator( diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..eb74903d68 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -16,14 +16,14 @@ from synapse.storage.util.id_generators import _load_current_id -class SlavedIdTracker(object): +class SlavedIdTracker: def __init__(self, db_conn, table, column, extra_tables=[], step=1): self.step = step self._current = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: - self.advance(_load_current_id(db_conn, table, column)) + self.advance(None, _load_current_id(db_conn, table, column)) - def advance(self, new_id): + def advance(self, instance_name, new_id): self._current = (max if self.step > 0 else min)(self._current, new_id) def get_current_token(self): @@ -33,3 +33,11 @@ class SlavedIdTracker(object): int """ return self._current + + def get_current_token_for_writer(self, instance_name: str) -> int: + """Returns the position of the given writer. + + For streams with single writers this is equivalent to + `get_current_token`. + """ + return self.get_current_token() diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..bb66ba9b80 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py
@@ -16,13 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore -from synapse.storage.data_stores.main.tags import TagsWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data", @@ -39,13 +40,13 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved return self._account_data_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "tag_account_data": - self._account_data_id_gen.advance(token) + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == "account_data": - self._account_data_id_gen.advance(token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a67fbeffb7..0f8d7037bd 100644 --- a/synapse/replication/slave/storage/appservice.py +++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.appservice import ( +from synapse.storage.databases.main.appservice import ( ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, ) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1a38f53dfb..a6fdedde63 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,22 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore class SlavedClientIpStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, max_entries=50000 ) - def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): + async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): now = int(self._clock.time_msec()) key = (user_id, access_token, ip) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..533d927701 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,17 +15,18 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import ToDeviceStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs) self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_max_stream_id", "stream_id" + db_conn, "device_inbox", "stream_id" ) self._device_inbox_stream_cache = StreamChangeCache( "DeviceInboxStreamChangeCache", @@ -44,8 +45,8 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): ) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "to_device": - self._device_inbox_id_gen.advance(token) + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) for row in rows: if row.entity.startswith("@"): self._device_inbox_stream_cache.entity_has_changed( diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 9d8067342f..3b788c9625 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py
@@ -16,14 +16,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream -from synapse.storage.data_stores.main.devices import DeviceWorkerStore -from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.devices import DeviceWorkerStore +from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedDeviceStore, self).__init__(database, db_conn, hs) self.hs = hs @@ -48,12 +48,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(token) + self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 8b9717c46f..1945bcf9a8 100644 --- a/synapse/replication/slave/storage/directory.py +++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.directory import DirectoryWorkerStore +from synapse.storage.databases.main.directory import DirectoryWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 1a1a50a24f..da1cc836cf 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py
@@ -15,18 +15,18 @@ # limitations under the License. import logging -from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore -from synapse.storage.data_stores.main.event_push_actions import ( +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.event_federation import EventFederationWorkerStore +from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) -from synapse.storage.data_stores.main.events_worker import EventsWorkerStore -from synapse.storage.data_stores.main.relations import RelationsWorkerStore -from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore -from synapse.storage.data_stores.main.signatures import SignatureWorkerStore -from synapse.storage.data_stores.main.state import StateGroupWorkerStore -from synapse.storage.data_stores.main.stream import StreamWorkerStore -from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore -from synapse.storage.database import Database +from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.relations import RelationsWorkerStore +from synapse.storage.databases.main.roommember import RoomMemberWorkerStore +from synapse.storage.databases.main.signatures import SignatureWorkerStore +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.databases.main.stream import StreamWorkerStore +from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -55,11 +55,11 @@ class SlavedEventStore( RelationsWorkerStore, BaseSlavedStore, ): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedEventStore, self).__init__(database, db_conn, hs) events_max = self._stream_id_gen.get_current_token() - curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict( + curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict( db_conn, "current_state_delta_stream", entity_column="room_id", diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index bcb0688954..2562b6fc38 100644 --- a/synapse/replication/slave/storage/filtering.py +++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.filtering import FilteringStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.filtering import FilteringStore from ._base import BaseSlavedStore class SlavedFilteringStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedFilteringStore, self).__init__(database, db_conn, hs) # Filters are immutable so this cache doesn't need to be expired diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..567b4a5cc1 100644 --- a/synapse/replication/slave/storage/groups.py +++ b/synapse/replication/slave/storage/groups.py
@@ -15,13 +15,14 @@ from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import GroupServerStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.group_server import GroupServerWorkerStore from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedGroupServerStore, self).__init__(database, db_conn, hs) self.hs = hs @@ -38,8 +39,8 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): return self._group_updates_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "groups": - self._group_updates_id_gen.advance(token) + if stream_name == GroupServerStream.NAME: + self._group_updates_id_gen.advance(instance_name, token) for row in rows: self._group_updates_stream_cache.entity_has_changed(row.user_id, token) diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index 3def367ae9..961579751c 100644 --- a/synapse/replication/slave/storage/keys.py +++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.keys import KeyStore +from synapse.storage.databases.main.keys import KeyStore # KeyStore isn't really safe to use from a worker, but for now we do so and hope that # the races it creates aren't too bad. diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..025f6f6be8 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py
@@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.replication.tcp.streams import PresenceStream from synapse.storage import DataStore -from synapse.storage.data_stores.main.presence import PresenceStore -from synapse.storage.database import Database +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.presence import PresenceStore from synapse.util.caches.stream_change_cache import StreamChangeCache from ._base import BaseSlavedStore @@ -23,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker class SlavedPresenceStore(BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") @@ -42,8 +43,8 @@ class SlavedPresenceStore(BaseSlavedStore): return self._presence_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "presence": - self._presence_id_gen.advance(token) + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 28c508aad3..f85b20a071 100644 --- a/synapse/replication/slave/storage/profile.py +++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.storage.data_stores.main.profile import ProfileWorkerStore +from synapse.storage.databases.main.profile import ProfileWorkerStore class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore): diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..de904c943c 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,24 +14,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import PushRulesStream +from synapse.storage.databases.main.push_rule import PushRulesWorkerStore from .events import SlavedEventStore class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore): - def get_push_rules_stream_token(self): - return ( - self._push_rules_stream_id_gen.get_current_token(), - self._stream_id_gen.get_current_token(), - ) - def get_max_push_rules_stream_id(self): return self._push_rules_stream_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "push_rules": - self._push_rules_stream_id_gen.advance(token) + # We assert this for the benefit of mypy + assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker) + + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) for row in rows: self.get_push_rules_for_user.invalidate((row.user_id,)) self.get_push_rules_enabled_for_user.invalidate((row.user_id,)) diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..9da218bfe8 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py
@@ -14,15 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.pusher import PusherWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import PushersStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.pusher import PusherWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(SlavedPusherStore, self).__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] @@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): return self._pushers_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "pushers": - self._pushers_id_gen.advance(token) + if stream_name == PushersStream.NAME: + self._pushers_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..5c2986e050 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py
@@ -14,23 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import ReceiptsStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker -# So, um, we want to borrow a load of functions intended for reading from -# a DataStore, but we don't want to take functions that either write to the -# DataStore or are cached and don't have cache invalidation logic. -# -# Rather than write duplicate versions of those functions, or lift them to -# a common base class, we going to grab the underlying __func__ object from -# the method descriptor on the DataStore and chuck them into our class. - class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): # We instantiate this first as the ReceiptsWorkerStore constructor # needs to be able to call get_max_receipt_stream_id self._receipts_id_gen = SlavedIdTracker( @@ -52,8 +45,8 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): self.get_receipts_for_room.invalidate((room_id, receipt_type)) def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "receipts": - self._receipts_id_gen.advance(token) + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) for row in rows: self.invalidate_caches_for_receipt( row.room_id, row.receipt_type, row.user_id diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 4b8553e250..a40f064e2b 100644 --- a/synapse/replication/slave/storage/registration.py +++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.registration import RegistrationWorkerStore +from synapse.storage.databases.main.registration import RegistrationWorkerStore from ._base import BaseSlavedStore diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..80ae803ad9 100644 --- a/synapse/replication/slave/storage/room.py +++ b/synapse/replication/slave/storage/room.py
@@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.room import RoomWorkerStore -from synapse.storage.database import Database +from synapse.replication.tcp.streams import PublicRoomsStream +from synapse.storage.database import DatabasePool +from synapse.storage.databases.main.room import RoomWorkerStore from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker class RoomStore(RoomWorkerStore, BaseSlavedStore): - def __init__(self, database: Database, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn, hs): super(RoomStore, self).__init__(database, db_conn, hs) self._public_room_id_gen = SlavedIdTracker( db_conn, "public_room_list_stream", "stream_id" @@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore): return self._public_room_id_gen.get_current_token() def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == "public_rooms": - self._public_room_id_gen.advance(token) + if stream_name == PublicRoomsStream.NAME: + self._public_room_id_gen.advance(instance_name, token) return super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index ac88e6b8c3..2091ac0df6 100644 --- a/synapse/replication/slave/storage/transactions.py +++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.data_stores.main.transactions import TransactionStore +from synapse.storage.databases.main.transactions import TransactionStore from ._base import BaseSlavedStore diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 523a1358d4..1b8718b11d 100644 --- a/synapse/replication/tcp/__init__.py +++ b/synapse/replication/tcp/__init__.py
@@ -25,7 +25,7 @@ Structure of the module: * command.py - the definitions of all the valid commands * protocol.py - the TCP protocol classes * resource.py - handles streaming stream updates to replications - * streams/ - the definitons of all the valid streams + * streams/ - the definitions of all the valid streams The general interaction of the classes are: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index df29732f51..d6ecf5b327 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@ # limitations under the License. """A replication client for use by synapse workers. """ -import heapq import logging from typing import TYPE_CHECKING, Dict, List, Tuple @@ -24,6 +23,7 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol +from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure if TYPE_CHECKING: - from synapse.server import HomeServer from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -104,6 +104,7 @@ class ReplicationDataHandler: self._clock = hs.get_clock() self._streams = hs.get_replication_streams() self._instance_name = hs.get_instance_name() + self._typing_handler = hs.get_typing_handler() # Map from stream to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. @@ -127,6 +128,12 @@ class ReplicationDataHandler: """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + if stream_name == TypingStream.NAME: + self._typing_handler.process_replication_rows(token, rows) + self.notifier.on_new_event( + "typing_key", token, rooms=[row.room_id for row in rows] + ) + if stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -211,9 +218,8 @@ class ReplicationDataHandler: waiting_list = self._streams_to_waiters.setdefault(stream_name, []) - # We insert into the list using heapq as it is more efficient than - # pushing then resorting each time. - heapq.heappush(waiting_list, (position, deferred)) + waiting_list.append((position, deferred)) + waiting_list.sort(key=lambda t: t[0]) # We measure here to get in flight counts and average waiting time. with Measure(self._clock, "repl.wait_for_stream_position"): diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py
@@ -19,17 +19,9 @@ allowed to be sent by which side. """ import abc import logging -import platform from typing import Tuple, Type -if platform.python_implementation() == "PyPy": - import json - - _json_encoder = json.JSONEncoder() -else: - import simplejson as json # type: ignore[no-redef] # noqa: F821 - - _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821 +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -54,7 +46,7 @@ class Command(metaclass=abc.ABCMeta): @abc.abstractmethod def to_line(self) -> str: - """Serialises the comamnd for the wire. Does not include the command + """Serialises the command for the wire. Does not include the command prefix. """ @@ -131,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -140,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -149,7 +141,7 @@ class RdataCommand(Command): class PositionCommand(Command): - """Sent by the server to tell the client the stream postition without + """Sent by the server to tell the client the stream position without needing to send an RDATA. Format:: @@ -188,7 +180,7 @@ class ErrorCommand(_SimpleCommand): class PingCommand(_SimpleCommand): - """Sent by either side as a keep alive. The data is arbitary (often timestamp) + """Sent by either side as a keep alive. The data is arbitrary (often timestamp) """ NAME = "PING" @@ -300,20 +292,22 @@ class FederationAckCommand(Command): Format:: - FEDERATION_ACK <token> + FEDERATION_ACK <instance_name> <token> """ NAME = "FEDERATION_ACK" - def __init__(self, token): + def __init__(self, instance_name, token): + self.instance_name = instance_name self.token = token @classmethod def from_line(cls, line): - return cls(int(line)) + instance_name, token = line.split(" ") + return cls(instance_name, int(token)) def to_line(self): - return str(self.token) + return "%s %s" % (self.instance_name, self.token) class RemovePusherCommand(Command): @@ -363,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -371,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..1c303f3a46 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py
@@ -13,15 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar +from typing import ( + Any, + Awaitable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) from prometheus_client import Counter +from typing_extensions import Deque from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import LaterGauge +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.client import DirectTcpReplicationClientFactory from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, @@ -43,8 +56,8 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + TypingStream, ) -from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) @@ -56,12 +69,16 @@ inbound_rdata_count = Counter( user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -invalidate_cache_counter = Counter( - "synapse_replication_tcp_resource_invalidate_cache", "" -) + user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +# the type of the entries in _command_queues_by_stream +_StreamCommandQueue = Deque[ + Tuple[Union[RdataCommand, PositionCommand], AbstractConnection] +] + + class ReplicationCommandHandler: """Handles incoming commands from replication as well as sending commands back out to connections. @@ -97,6 +114,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, TypingStream): + # Only add TypingStream as a source on the instance in charge of + # typing. + if hs.config.worker.writers.typing == hs.get_instance_name(): + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue @@ -108,12 +133,8 @@ class ReplicationCommandHandler: self._streams_to_replicate.append(stream) - self._position_linearizer = Linearizer( - "replication_position", clock=self._clock - ) - - # Map of stream to batched updates. See RdataCommand for info on how - # batching works. + # Map of stream name to batched updates. See RdataCommand for info on + # how batching works. self._pending_batches = {} # type: Dict[str, List[Any]] # The factory used to create connections. @@ -123,9 +144,6 @@ class ReplicationCommandHandler: # outgoing replication commands to.) self._connections = [] # type: List[AbstractConnection] - # For each connection, the incoming streams that are coming from that connection - self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] - LaterGauge( "synapse_replication_tcp_resource_total_connections", "", @@ -133,6 +151,32 @@ class ReplicationCommandHandler: lambda: len(self._connections), ) + # When POSITION or RDATA commands arrive, we stick them in a queue and process + # them in order in a separate background process. + + # the streams which are currently being processed by _unsafe_process_queue + self._processing_streams = set() # type: Set[str] + + # for each stream, a queue of commands that are awaiting processing, and the + # connection that they arrived on. + self._command_queues_by_stream = { + stream_name: _StreamCommandQueue() for stream_name in self._streams + } + + # For each connection, the incoming stream names that have received a POSITION + # from that connection. + self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] + + LaterGauge( + "synapse_replication_tcp_command_queue", + "Number of inbound RDATA/POSITION commands queued for processing", + ["stream_name"], + lambda: { + (stream_name,): len(queue) + for stream_name, queue in self._command_queues_by_stream.items() + }, + ) + self._is_master = hs.config.worker_app is None self._federation_sender = None @@ -143,15 +187,75 @@ class ReplicationCommandHandler: if self._is_master: self._server_notices_sender = hs.get_server_notices_sender() + def _add_command_to_stream_queue( + self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand] + ) -> None: + """Queue the given received command for processing + + Adds the given command to the per-stream queue, and processes the queue if + necessary + """ + stream_name = cmd.stream_name + queue = self._command_queues_by_stream.get(stream_name) + if queue is None: + logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name) + return + + queue.append((cmd, conn)) + + # if we're already processing this stream, there's nothing more to do: + # the new entry on the queue will get picked up in due course + if stream_name in self._processing_streams: + return + + # fire off a background process to start processing the queue. + run_as_background_process( + "process-replication-data", self._unsafe_process_queue, stream_name + ) + + async def _unsafe_process_queue(self, stream_name: str): + """Processes the command queue for the given stream, until it is empty + + Does not check if there is already a thread processing the queue, hence "unsafe" + """ + assert stream_name not in self._processing_streams + + self._processing_streams.add(stream_name) + try: + queue = self._command_queues_by_stream.get(stream_name) + while queue: + cmd, conn = queue.popleft() + try: + await self._process_command(cmd, conn, stream_name) + except Exception: + logger.exception("Failed to handle command %s", cmd) + finally: + self._processing_streams.discard(stream_name) + + async def _process_command( + self, + cmd: Union[PositionCommand, RdataCommand], + conn: AbstractConnection, + stream_name: str, + ) -> None: + if isinstance(cmd, PositionCommand): + await self._process_position(stream_name, conn, cmd) + elif isinstance(cmd, RdataCommand): + await self._process_rdata(stream_name, conn, cmd) + else: + # This shouldn't be possible + raise Exception("Unrecognised command %s in stream queue", cmd.NAME) + def start_replication(self, hs): """Helper method to start a replication connection to the remote server using TCP. """ if hs.config.redis.redis_enabled: + import txredisapi + from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, ) - import txredisapi logger.info( "Connecting to redis (host=%r port=%r)", @@ -198,7 +302,7 @@ class ReplicationCommandHandler: """ return self._streams_to_replicate - async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): + def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand): self.send_positions_to_connection(conn) def send_positions_to_connection(self, conn: AbstractConnection): @@ -217,57 +321,73 @@ class ReplicationCommandHandler: ) ) - async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand): + def on_USER_SYNC( + self, conn: AbstractConnection, cmd: UserSyncCommand + ) -> Optional[Awaitable[None]]: user_sync_counter.inc() if self._is_master: - await self._presence_handler.update_external_syncs_row( + return self._presence_handler.update_external_syncs_row( cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms ) + else: + return None - async def on_CLEAR_USER_SYNC( + def on_CLEAR_USER_SYNC( self, conn: AbstractConnection, cmd: ClearUserSyncsCommand - ): + ) -> Optional[Awaitable[None]]: if self._is_master: - await self._presence_handler.update_external_syncs_clear(cmd.instance_id) + return self._presence_handler.update_external_syncs_clear(cmd.instance_id) + else: + return None - async def on_FEDERATION_ACK( - self, conn: AbstractConnection, cmd: FederationAckCommand - ): + def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand): federation_ack_counter.inc() if self._federation_sender: - self._federation_sender.federation_ack(cmd.token) + self._federation_sender.federation_ack(cmd.instance_name, cmd.token) - async def on_REMOVE_PUSHER( + def on_REMOVE_PUSHER( self, conn: AbstractConnection, cmd: RemovePusherCommand - ): + ) -> Optional[Awaitable[None]]: remove_pusher_counter.inc() if self._is_master: - await self._store.delete_pusher_by_app_id_pushkey_user_id( - app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id - ) + return self._handle_remove_pusher(cmd) + else: + return None + + async def _handle_remove_pusher(self, cmd: RemovePusherCommand): + await self._store.delete_pusher_by_app_id_pushkey_user_id( + app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id + ) - self._notifier.on_new_replication_data() + self._notifier.on_new_replication_data() - async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand): + def on_USER_IP( + self, conn: AbstractConnection, cmd: UserIpCommand + ) -> Optional[Awaitable[None]]: user_ip_cache_counter.inc() if self._is_master: - await self._store.insert_client_ip( - cmd.user_id, - cmd.access_token, - cmd.ip, - cmd.user_agent, - cmd.device_id, - cmd.last_seen, - ) + return self._handle_user_ip(cmd) + else: + return None + + async def _handle_user_ip(self, cmd: UserIpCommand): + await self._store.insert_client_ip( + cmd.user_id, + cmd.access_token, + cmd.ip, + cmd.user_agent, + cmd.device_id, + cmd.last_seen, + ) - if self._server_notices_sender: - await self._server_notices_sender.on_user_ip(cmd.user_id) + assert self._server_notices_sender is not None + await self._server_notices_sender.on_user_ip(cmd.user_id) - async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): + def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand): if cmd.instance_name == self._instance_name: # Ignore RDATA that are just our own echoes return @@ -275,42 +395,71 @@ class ReplicationCommandHandler: stream_name = cmd.stream_name inbound_rdata_count.labels(stream_name).inc() - try: - row = STREAMS_MAP[stream_name].parse_row(cmd.row) - except Exception: - logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row) - raise - - # We linearize here for two reasons: + # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the # same stream, and # 2. so we don't race with getting a POSITION command and fetching # missing RDATA. - with await self._position_linearizer.queue(cmd.stream_name): - # make sure that we've processed a POSITION for this stream *on this - # connection*. (A POSITION on another connection is no good, as there - # is no guarantee that we have seen all the intermediate updates.) - sbc = self._streams_by_connection.get(conn) - if not sbc or stream_name not in sbc: - # Let's drop the row for now, on the assumption we'll receive a - # `POSITION` soon and we'll catch up correctly then. - logger.debug( - "Discarding RDATA for unconnected stream %s -> %s", - stream_name, - cmd.token, - ) - return - - if cmd.token is None: - # I.e. this is part of a batch of updates for this stream (in - # which case batch until we get an update for the stream with a non - # None token). - self._pending_batches.setdefault(stream_name, []).append(row) - else: - # Check if this is the last of a batch of updates - rows = self._pending_batches.pop(stream_name, []) - rows.append(row) - await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) + + self._add_command_to_stream_queue(conn, cmd) + + async def _process_rdata( + self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand + ) -> None: + """Process an RDATA command + + Called after the command has been popped off the queue of inbound commands + """ + try: + row = STREAMS_MAP[stream_name].parse_row(cmd.row) + except Exception as e: + raise Exception( + "Failed to parse RDATA: %r %r" % (stream_name, cmd.row) + ) from e + + # make sure that we've processed a POSITION for this stream *on this + # connection*. (A POSITION on another connection is no good, as there + # is no guarantee that we have seen all the intermediate updates.) + sbc = self._streams_by_connection.get(conn) + if not sbc or stream_name not in sbc: + # Let's drop the row for now, on the assumption we'll receive a + # `POSITION` soon and we'll catch up correctly then. + logger.debug( + "Discarding RDATA for unconnected stream %s -> %s", + stream_name, + cmd.token, + ) + return + + if cmd.token is None: + # I.e. this is part of a batch of updates for this stream (in + # which case batch until we get an update for the stream with a non + # None token). + self._pending_batches.setdefault(stream_name, []).append(row) + return + + # Check if this is the last of a batch of updates + rows = self._pending_batches.pop(stream_name, []) + rows.append(row) + + stream = self._streams[stream_name] + + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) + + # Discard this data if this token is earlier than the current + # position. Note that streams can be reset (in which case you + # expect an earlier token), but that must be preceded by a + # POSITION command. + if cmd.token <= current_token: + logger.debug( + "Discarding RDATA from stream %s at position %s before previous position %s", + stream_name, + cmd.token, + current_token, + ) + else: + await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -329,78 +478,74 @@ class ReplicationCommandHandler: stream_name, instance_name, token, rows ) - async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): + def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): if cmd.instance_name == self._instance_name: # Ignore POSITION that are just our own echoes return logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line()) - stream_name = cmd.stream_name - stream = self._streams.get(stream_name) - if not stream: - logger.error("Got POSITION for unknown stream: %s", stream_name) - return + self._add_command_to_stream_queue(conn, cmd) - # We protect catching up with a linearizer in case the replication - # connection reconnects under us. - with await self._position_linearizer.queue(stream_name): - # We're about to go and catch up with the stream, so remove from set - # of connected streams. - for streams in self._streams_by_connection.values(): - streams.discard(stream_name) - - # We clear the pending batches for the stream as the fetching of the - # missing updates below will fetch all rows in the batch. - self._pending_batches.pop(stream_name, []) - - # Find where we previously streamed up to. - current_token = stream.current_token(cmd.instance_name) - - # If the position token matches our current token then we're up to - # date and there's nothing to do. Otherwise, fetch all updates - # between then and now. - missing_updates = cmd.token != current_token - while missing_updates: - logger.info( - "Fetching replication rows for '%s' between %i and %i", - stream_name, - current_token, - cmd.token, - ) - ( - updates, - current_token, - missing_updates, - ) = await stream.get_updates_since( - cmd.instance_name, current_token, cmd.token - ) + async def _process_position( + self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand + ) -> None: + """Process a POSITION command - # TODO: add some tests for this + Called after the command has been popped off the queue of inbound commands + """ + stream = self._streams[stream_name] - # Some streams return multiple rows with the same stream IDs, - # which need to be processed in batches. + # We're about to go and catch up with the stream, so remove from set + # of connected streams. + for streams in self._streams_by_connection.values(): + streams.discard(stream_name) - for token, rows in _batch_updates(updates): - await self.on_rdata( - stream_name, - cmd.instance_name, - token, - [stream.parse_row(row) for row in rows], - ) + # We clear the pending batches for the stream as the fetching of the + # missing updates below will fetch all rows in the batch. + self._pending_batches.pop(stream_name, []) - logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + # Find where we previously streamed up to. + current_token = stream.current_token(cmd.instance_name) - # We've now caught up to position sent to us, notify handler. - await self._replication_data_handler.on_position( - cmd.stream_name, cmd.instance_name, cmd.token + # If the position token matches our current token then we're up to + # date and there's nothing to do. Otherwise, fetch all updates + # between then and now. + missing_updates = cmd.token != current_token + while missing_updates: + logger.info( + "Fetching replication rows for '%s' between %i and %i", + stream_name, + current_token, + cmd.token, + ) + (updates, current_token, missing_updates) = await stream.get_updates_since( + cmd.instance_name, current_token, cmd.token ) - self._streams_by_connection.setdefault(conn, set()).add(stream_name) + # TODO: add some tests for this - async def on_REMOTE_SERVER_UP( - self, conn: AbstractConnection, cmd: RemoteServerUpCommand - ): + # Some streams return multiple rows with the same stream IDs, + # which need to be processed in batches. + + for token, rows in _batch_updates(updates): + await self.on_rdata( + stream_name, + cmd.instance_name, + token, + [stream.parse_row(row) for row in rows], + ) + + logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) + + # We've now caught up to position sent to us, notify handler. + await self._replication_data_handler.on_position( + cmd.stream_name, cmd.instance_name, cmd.token + ) + + self._streams_by_connection.setdefault(conn, set()).add(stream_name) + + def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand): """"Called when get a new REMOTE_SERVER_UP command.""" self._replication_data_handler.on_remote_server_up(cmd.data) @@ -505,7 +650,7 @@ class ReplicationCommandHandler: """Ack data for the federation stream. This allows the master to drop data stored purely in memory. """ - self.send_command(FederationAckCommand(token)) + self.send_command(FederationAckCommand(self._instance_name, token)) def send_user_sync( self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 4198eece71..0b0d204e64 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc import fcntl import logging import struct +from inspect import isawaitable from typing import TYPE_CHECKING, List from prometheus_client import Counter @@ -57,8 +58,12 @@ from prometheus_client import Counter from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure +from synapse.logging.context import PreserveLoggingContext from synapse.metrics import LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS, @@ -108,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5 PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER -class ConnectionStates(object): +class ConnectionStates: CONNECTING = "connecting" ESTABLISHED = "established" PAUSED = "paused" @@ -124,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`. + `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine; + if so, that will get run as a background process. It also sends `PING` periodically, and correctly times out remote connections (if they send a `PING` command) @@ -160,6 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # The LoopingCall for sending pings. self._send_ping_loop = None + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + ctx_name = "replication-conn-%s" % self.conn_id + self._logging_context = BackgroundProcessLoggingContext(ctx_name) + self._logging_context.request = ctx_name + def connectionMade(self): logger.info("[%s] Connection established", self.id()) @@ -210,6 +223,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def lineReceived(self, line: bytes): """Called when we've received a line """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_line(line) + + def _parse_and_dispatch_line(self, line: bytes): if line.strip() == "": # Ignore blank lines return @@ -232,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. First calls `self.on_<COMMAND>` if it exists, then calls - `self.command_handler.on_<COMMAND>` if it exists. This allows for - protocol level handling of commands (e.g. PINGs), before delegating to - the handler. + `self.command_handler.on_<COMMAND>` if it exists (which can optionally + return an Awaitable). + + This allows for protocol level handling of commands (e.g. PINGs), before + delegating to the handler. Args: cmd: received command @@ -254,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # specific handling. cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(cmd) + cmd_func(cmd) handled = True # Then call out to the handler. cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None) if cmd_func: - await cmd_func(self, cmd) + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) + handled = True if not handled: @@ -317,7 +342,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def _queue_command(self, cmd): """Queue the command until the connection is ready to write to again. """ - logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd) + logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd) self.pending_commands.append(cmd) if len(self.pending_commands) > self.max_line_buffer: @@ -336,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): for cmd in pending: self.send_command(cmd) - async def on_PING(self, line): + def on_PING(self, line): self.received_ping = True - async def on_ERROR(self, cmd): + def on_ERROR(self, cmd): logger.error("[%s] Remote reported error: %r", self.id(), cmd.data) def pauseProducing(self): @@ -397,6 +422,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if self.transport: self.transport.unregisterProducer() + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def __str__(self): addr = None if self.transport: @@ -431,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): self.send_command(ServerCommand(self.server_name)) super().connectionMade() - async def on_NAME(self, cmd): + def on_NAME(self, cmd): logger.info("[%s] Renamed to %r", self.id(), cmd.data) self.name = cmd.data @@ -460,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # Once we've connected subscribe to the necessary streams self.replicate() - async def on_SERVER(self, cmd): + def on_SERVER(self, cmd): if cmd.data != self.server_name: logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) self.send_error("Wrong remote") diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..f225e533de 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -14,12 +14,16 @@ # limitations under the License. import logging +from inspect import isawaitable from typing import TYPE_CHECKING import txredisapi -from synapse.logging.context import make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics.background_process_metrics import ( + BackgroundProcessLoggingContext, + run_as_background_process, +) from synapse.replication.tcp.commands import ( Command, ReplicateCommand, @@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): stream_name = None # type: str outbound_redis_connection = None # type: txredisapi.RedisProtocol + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + self._logging_context = BackgroundProcessLoggingContext( + "replication_command_handler" + ) + def connectionMade(self): logger.info("Connected to redis") super().connectionMade() @@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): def messageReceived(self, pattern: str, channel: str, message: str): """Received a message from redis. """ + with PreserveLoggingContext(self._logging_context): + self._parse_and_dispatch_message(message) + def _parse_and_dispatch_message(self, message: str): if message.strip() == "": # Ignore blank lines return @@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection): # remote instances. tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() - # Now lets try and call on_<CMD_NAME> function - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd - ) + self.handle_command(cmd) - async def handle_command(self, cmd: Command): + def handle_command(self, cmd: Command) -> None: """Handle a command we have received over the replication stream. - By default delegates to on_<COMMAND>, which should return an awaitable. + Delegates to `self.handler.on_<COMMAND>` (which can optionally return an + Awaitable). Args: cmd: received command """ - handled = False - - # First call any command handlers on this instance. These are for redis - # specific handling. - cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(cmd) - handled = True - # Then call out to the handler. cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None) - if cmd_func: - await cmd_func(self, cmd) - handled = True - - if not handled: + if not cmd_func: logger.warning("Unhandled command: %r", cmd) + return + + res = cmd_func(self, cmd) + + # the handler might be a coroutine: fire it off as a background process + # if so. + + if isawaitable(res): + run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res + ) def connectionLost(self, reason): logger.info("Lost connection to redis") super().connectionLost(reason) self.handler.lost_connection(self) + # mark the logging context as finished + self._logging_context.__exit__(None, None, None) + def send_command(self, cmd: Command): """Send a command if connection has been established. @@ -177,7 +192,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory): Args: hs outbound_redis_connection: A connection to redis that will be used to - send outbound commands (this is seperate to the redis connection + send outbound commands (this is separate to the redis connection used to subscribe). """ diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 41569305df..04d894fb3d 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py
@@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory): ) -class ReplicationStreamer(object): +class ReplicationStreamer: """Handles replication connections. This needs to be poked when new replication data may be available. When new diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4acefc8a96..682d47f402 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py
@@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool] UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]] -class Stream(object): +class Stream: """Base class for the streams. Provides a `get_updates()` function that returns new updates since the last @@ -198,26 +198,6 @@ def current_token_without_instance( return lambda instance_name: current_token() -def db_query_to_update_function( - query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] -) -> UpdateFunction: - """Wraps a db query function which returns a list of rows to make it - suitable for use as an `update_function` for the Stream class - """ - - async def update_function(instance_name, from_token, upto_token, limit): - rows = await query_function(from_token, upto_token, limit) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - - return update_function - - def make_http_update_function(hs, stream_name: str) -> UpdateFunction: """Makes a suitable function for use as an `update_function` that queries the master process for updates. @@ -264,7 +244,7 @@ class BackfillStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_backfill_token), - db_query_to_update_function(store.get_all_new_backfill_event_rows), + store.get_all_new_backfill_event_rows, ) @@ -291,9 +271,7 @@ class PresenceStream(Stream): if hs.config.worker_app is None: # on the master, query the presence handler presence_handler = hs.get_presence_handler() - update_function = db_query_to_update_function( - presence_handler.get_all_presence_updates - ) + update_function = presence_handler.get_all_presence_updates else: # Query master process update_function = make_http_update_function(hs, self.NAME) @@ -316,13 +294,12 @@ class TypingStream(Stream): def __init__(self, hs): typing_handler = hs.get_typing_handler() - if hs.config.worker_app is None: - # on the master, query the typing handler - update_function = db_query_to_update_function( - typing_handler.get_all_typing_updates - ) + writer_instance = hs.config.worker.writers.typing + if writer_instance == hs.get_instance_name(): + # On the writer, query the typing handler + update_function = typing_handler.get_all_typing_updates else: - # Query master process + # Query the typing writer process update_function = make_http_update_function(hs, self.NAME) super().__init__( @@ -352,7 +329,7 @@ class ReceiptsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_receipt_stream_id), - db_query_to_update_function(store.get_all_updated_receipts), + store.get_all_updated_receipts, ) @@ -367,26 +344,17 @@ class PushRulesStream(Stream): def __init__(self, hs): self.store = hs.get_datastore() + super(PushRulesStream, self).__init__( - hs.get_instance_name(), self._current_token, self._update_function + hs.get_instance_name(), + self._current_token, + self.store.get_all_push_rule_updates, ) def _current_token(self, instance_name: str) -> int: - push_rules_token, _ = self.store.get_push_rules_stream_token() + push_rules_token = self.store.get_max_push_rules_stream_id() return push_rules_token - async def _update_function( - self, instance_name: str, from_token: Token, to_token: Token, limit: int - ): - rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) - - limited = False - if len(rows) == limit: - to_token = rows[-1][0] - limited = True - - return [(row[0], (row[2],)) for row in rows], to_token, limited - class PushersStream(Stream): """A user has added/changed/removed a pusher @@ -406,7 +374,7 @@ class PushersStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_pushers_stream_token), - db_query_to_update_function(store.get_all_updated_pushers_rows), + store.get_all_updated_pushers_rows, ) @@ -434,27 +402,13 @@ class CachesStream(Stream): ROW_TYPE = CachesStreamRow def __init__(self, hs): - self.store = hs.get_datastore() + store = hs.get_datastore() super().__init__( hs.get_instance_name(), - self.store.get_cache_stream_token, - self._update_function, + store.get_cache_stream_token_for_writer, + store.get_all_updated_caches, ) - async def _update_function( - self, instance_name: str, from_token: int, upto_token: int, limit: int - ): - rows = await self.store.get_all_updated_caches( - instance_name, from_token, upto_token, limit - ) - updates = [(row[0], row[1:]) for row in rows] - limited = False - if len(updates) >= limit: - upto_token = updates[-1][0] - limited = True - - return updates, upto_token, limited - class PublicRoomsStream(Stream): """The public rooms list changed @@ -478,7 +432,7 @@ class PublicRoomsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_current_public_room_stream_id), - db_query_to_update_function(store.get_all_new_public_rooms), + store.get_all_new_public_rooms, ) @@ -499,7 +453,7 @@ class DeviceListsStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function(store.get_all_device_list_changes_for_remotes), + store.get_all_device_list_changes_for_remotes, ) @@ -517,7 +471,7 @@ class ToDeviceStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_to_device_stream_token), - db_query_to_update_function(store.get_all_new_device_messages), + store.get_all_new_device_messages, ) @@ -537,7 +491,7 @@ class TagAccountDataStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_max_account_data_stream_id), - db_query_to_update_function(store.get_all_updated_tags), + store.get_all_updated_tags, ) @@ -625,7 +579,7 @@ class GroupServerStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_group_stream_token), - db_query_to_update_function(store.get_all_groups_changes), + store.get_all_groups_changes, ) @@ -643,7 +597,5 @@ class UserSignatureStream(Stream): super().__init__( hs.get_instance_name(), current_token_without_instance(store.get_device_stream_token), - db_query_to_update_function( - store.get_all_user_signature_changes_for_remotes - ), + store.get_all_user_signature_changes_for_remotes, ) diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f370390331..f929fc3954 100644 --- a/synapse/replication/tcp/streams/events.py +++ b/synapse/replication/tcp/streams/events.py
@@ -13,16 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import heapq -from collections import Iterable +from collections.abc import Iterable from typing import List, Tuple, Type import attr from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance - """Handling of the 'events' replication stream This stream contains rows of various types. Each row therefore contains a 'type' @@ -51,20 +49,20 @@ data part are: @attr.s(slots=True, frozen=True) -class EventsStreamRow(object): +class EventsStreamRow: """A parsed row from the events replication stream""" type = attr.ib() # str: the TypeId of one of the *EventsStreamRows data = attr.ib() # BaseEventsStreamRow -class BaseEventsStreamRow(object): +class BaseEventsStreamRow: """Base class for rows to be sent in the events stream. Specifies how to identify, serialize and deserialize the different types. """ - # Unique string that ids the type. Must be overriden in sub classes. + # Unique string that ids the type. Must be overridden in sub classes. TypeId = None # type: str @classmethod