diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 19b69e0e11..a84a064c8d 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
- JsonResource.__init__(self, hs, canonical_json=False)
+ # We enable extracting jaeger contexts here as these are internal APIs.
+ super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)
def register_servlets(self, hs):
@@ -38,10 +39,10 @@ class ReplicationRestResource(JsonResource):
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)
+ streams.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
- streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 793cef6c26..ba16f22c91 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,32 +16,24 @@
import abc
import logging
import re
+import urllib
from inspect import signature
from typing import Dict, List, Tuple
-from six import raise_from
-from six.moves import urllib
-
-from twisted.internet import defer
-
from synapse.api.errors import (
CodeMessageException,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
-from synapse.logging.opentracing import (
- inject_active_span_byte_dict,
- trace,
- trace_servlet,
-)
+from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
-class ReplicationEndpoint(object):
+class ReplicationEndpoint:
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
@@ -98,16 +90,16 @@ class ReplicationEndpoint(object):
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
- ), "`instance_name` is a reserved paramater name"
+ ), "`instance_name` is a reserved parameter name"
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
- def _serialize_payload(**kwargs):
+ async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
Concrete implementations should have explicit parameters (rather than
@@ -116,9 +108,8 @@ class ReplicationEndpoint(object):
argument list.
Returns:
- Deferred[dict]|dict: If POST/PUT request then dictionary must be
- JSON serialisable, otherwise must be appropriate for adding as
- query args.
+ dict: If POST/PUT request then dictionary must be JSON serialisable,
+ otherwise must be appropriate for adding as query args.
"""
return {}
@@ -150,8 +141,7 @@ class ReplicationEndpoint(object):
instance_map = hs.config.worker.instance_map
@trace(opname="outgoing_replication_request")
- @defer.inlineCallbacks
- def send_request(instance_name="master", **kwargs):
+ async def send_request(instance_name="master", **kwargs):
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
@@ -165,7 +155,7 @@ class ReplicationEndpoint(object):
"Instance %r not in 'instance_map' config" % (instance_name,)
)
- data = yield cls._serialize_payload(**kwargs)
+ data = await cls._serialize_payload(**kwargs)
url_args = [
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
@@ -203,7 +193,7 @@ class ReplicationEndpoint(object):
headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
- result = yield request_func(uri, data, headers=headers)
+ result = await request_func(uri, data, headers=headers)
break
except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@@ -213,14 +203,14 @@ class ReplicationEndpoint(object):
# If we timed out we probably don't need to worry about backing
# off too much, but let's just wait a little anyway.
- yield clock.sleep(1)
+ await clock.sleep(1)
except HttpResponseException as e:
# We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
except RequestSendFailed as e:
- raise_from(SynapseError(502, "Failed to talk to master"), e)
+ raise SynapseError(502, "Failed to talk to master") from e
return result
@@ -242,11 +232,8 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
- handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
- # We don't let register paths trace this servlet using the default tracing
- # options because we wish to extract the context explicitly.
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__, trace=False
+ method, [pattern], handler, self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):
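For reference, a minimal sketch of a servlet written against the updated base class (the servlet, its name, and its payload fields are hypothetical, not part of this change). `_serialize_payload` is now an async static method; names listed in PATH_ARGS become URL components, and everything else travels in the JSON body (PUT/POST) or the query string (GET):

    from synapse.http.servlet import parse_json_object_from_request
    from synapse.replication.http._base import ReplicationEndpoint

    class ReplicationExampleRestServlet(ReplicationEndpoint):
        # served at /_synapse/replication/example/:user_id
        NAME = "example"
        PATH_ARGS = ("user_id",)
        METHOD = "POST"

        def __init__(self, hs):
            super().__init__(hs)
            self.store = hs.get_datastore()

        @staticmethod
        async def _serialize_payload(user_id, reason):
            # user_id is a PATH_ARG, so only `reason` goes in the body
            return {"reason": reason}

        async def _handle_request(self, request, user_id):
            content = parse_json_object_from_request(request)
            # ... act on content["reason"] ...
            return 200, {}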
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
index e32aac0a25..20f3ba76c0 100644
--- a/synapse/replication/http/devices.py
+++ b/synapse/replication/http/devices.py
@@ -60,7 +60,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(user_id):
+ async def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index c287c4e269..6b56315148 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
@@ -67,8 +65,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
- @defer.inlineCallbacks
- def _serialize_payload(store, event_and_contexts, backfilled):
+ async def _serialize_payload(store, event_and_contexts, backfilled):
"""
Args:
store
@@ -78,7 +75,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""
event_payloads = []
for event, context in event_and_contexts:
- serialized_context = yield context.serialize(event, store)
+ serialized_context = await context.serialize(event, store)
event_payloads.append(
{
@@ -154,7 +151,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
- def _serialize_payload(edu_type, origin, content):
+ async def _serialize_payload(edu_type, origin, content):
return {"origin": origin, "content": content}
async def _handle_request(self, request, edu_type):
@@ -197,7 +194,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
self.registry = hs.get_federation_registry()
@staticmethod
- def _serialize_payload(query_type, args):
+ async def _serialize_payload(query_type, args):
"""
Args:
query_type (str)
@@ -238,7 +235,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
- def _serialize_payload(room_id, args):
+ async def _serialize_payload(room_id, args):
"""
Args:
room_id (str)
@@ -273,7 +270,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
self.store = hs.get_datastore()
@staticmethod
- def _serialize_payload(room_id, room_version):
+ async def _serialize_payload(room_id, room_version):
return {"room_version": room_version.identifier}
async def _handle_request(self, request, room_id):
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 798b9d3af5..fb326bb869 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+ async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index a7174c4a8f..741329ab5f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,11 +14,11 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
-from synapse.types import Requester, UserID
+from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
if TYPE_CHECKING:
@@ -52,7 +52,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+ async def _serialize_payload(
+ requester, room_id, user_id, remote_room_hosts, content
+ ):
"""
Args:
requester(Requester)
@@ -88,49 +90,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room.
+ """Rejects an out-of-band invite we have received from a remote server
Request format:
- POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
+ POST /_synapse/replication/remote_reject_invite/:event_id
{
+ "txn_id": ...,
"requester": ...,
- "remote_room_hosts": [...],
"content": { ... }
}
"""
NAME = "remote_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
+ PATH_ARGS = ("invite_event_id",)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
- self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
+ async def _serialize_payload( # type: ignore
+ invite_event_id: str,
+ txn_id: Optional[str],
+ requester: Requester,
+ content: JsonDict,
+ ):
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and reject via
+ invite_event_id: ID of the invite to be rejected
+ txn_id: optional transaction ID supplied by the client
+ requester: user making the rejection request, according to the access token
+ content: additional content to include in the rejection event.
+ Normally an empty dict.
"""
return {
+ "txn_id": txn_id,
"requester": requester.serialize(),
- "remote_room_hosts": remote_room_hosts,
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, invite_event_id):
content = parse_json_object_from_request(request)
- remote_room_hosts = content["remote_room_hosts"]
+ txn_id = content["txn_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -138,60 +145,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
-
- try:
- event, stream_id = await self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, user_id, event_content,
- )
- event_id = event.event_id
- except Exception as e:
- # if we were unable to reject the exception, just mark
- # it as rejected on our end and plough ahead.
- #
- # The 'except' clause is very broad, but we need to
- # capture everything from DNS failures upwards
- #
- logger.warning("Failed to reject invite: %s", e)
-
- stream_id = await self.member_handler.locally_reject_invite(
- user_id, room_id
- )
- event_id = None
+ # hopefully we're now on the master, so this won't recurse!
+ event_id, stream_id = await self.member_handler.remote_reject_invite(
+ invite_event_id, txn_id, requester, event_content,
+ )
return 200, {"event_id": event_id, "stream_id": stream_id}
-class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
- """Rejects the invite for the user and room locally.
-
- Request format:
-
- POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
-
- {}
- """
-
- NAME = "locally_reject_invite"
- PATH_ARGS = ("room_id", "user_id")
-
- def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
-
- self.member_handler = hs.get_room_member_handler()
-
- @staticmethod
- def _serialize_payload(room_id, user_id):
- return {}
-
- async def _handle_request(self, request, room_id, user_id):
- logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
-
- stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
-
- return 200, {"stream_id": stream_id}
-
-
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@@ -215,7 +176,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
- def _serialize_payload(room_id, user_id, change):
+ async def _serialize_payload(room_id, user_id, change):
"""
Args:
room_id (str)
@@ -245,4 +206,3 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
- ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
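A hedged sketch of how a worker now calls the reworked rejection endpoint through the stub that ReplicationEndpoint.make_client builds (the event ID below is illustrative); the handler's response carries both the rejection event ID and the stream position:

    remote_reject = ReplicationRemoteRejectInviteRestServlet.make_client(hs)

    result = await remote_reject(
        invite_event_id="$invite123:example.com",  # hypothetical invite event ID
        txn_id=None,
        requester=requester,
        content={},
    )
    event_id, stream_id = result["event_id"], result["stream_id"]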
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
index ea1b33331b..bc9aa82cb4 100644
--- a/synapse/replication/http/presence.py
+++ b/synapse/replication/http/presence.py
@@ -50,7 +50,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- def _serialize_payload(user_id):
+ async def _serialize_payload(user_id):
return {}
async def _handle_request(self, request, user_id):
@@ -92,7 +92,7 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
self._presence_handler = hs.get_presence_handler()
@staticmethod
- def _serialize_payload(user_id, state, ignore_status_msg=False):
+ async def _serialize_payload(user_id, state, ignore_status_msg=False):
return {
"state": state,
"ignore_status_msg": ignore_status_msg,
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 0c4aca1291..a02b27474d 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -34,7 +34,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(
+ async def _serialize_payload(
user_id,
password_hash,
was_guest,
@@ -44,6 +44,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin,
user_type,
address,
+ shadow_banned,
):
"""
Args:
@@ -60,6 +61,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
address (str|None): the IP address used to perform the registration.
+ shadow_banned (bool): Whether to shadow-ban the user.
"""
return {
"password_hash": password_hash,
@@ -70,6 +72,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"admin": admin,
"user_type": user_type,
"address": address,
+ "shadow_banned": shadow_banned,
}
async def _handle_request(self, request, user_id):
@@ -87,6 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
admin=content["admin"],
user_type=content["user_type"],
address=content["address"],
+ shadow_banned=content["shadow_banned"],
)
return 200, {}
@@ -105,7 +109,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, auth_result, access_token):
+ async def _serialize_payload(user_id, auth_result, access_token):
"""
Args:
user_id (str): The user ID that consented
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index c981723c1a..f13d452426 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
@@ -62,8 +60,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- @defer.inlineCallbacks
- def _serialize_payload(
+ async def _serialize_payload(
event_id, store, event, context, requester, ratelimit, extra_users
):
"""
@@ -77,7 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event
"""
- serialized_context = yield context.serialize(event, store)
+ serialized_context = await context.serialize(event, store)
payload = {
"event": event.get_pdu_json(),
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index bde97eef32..309159e304 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -54,7 +54,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streams()
@staticmethod
- def _serialize_payload(stream_name, from_token, upto_token):
+ async def _serialize_payload(stream_name, from_token, upto_token):
return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request(self, request, stream_name):
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f9e2533e96..60f2e1245f 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -16,8 +16,8 @@
import logging
from typing import Optional
-from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class BaseSlavedStore(CacheInvalidationWorkerStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = MultiWriterIdGenerator(
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 9d1d173b2f..eb74903d68 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -16,14 +16,14 @@
from synapse.storage.util.id_generators import _load_current_id
-class SlavedIdTracker(object):
+class SlavedIdTracker:
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
- self.advance(_load_current_id(db_conn, table, column))
+ self.advance(None, _load_current_id(db_conn, table, column))
- def advance(self, new_id):
+ def advance(self, instance_name, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
int
"""
return self._current
+
+ def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to
+ `get_current_token`.
+ """
+ return self.get_current_token()
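A small sketch of the revised tracker contract (table name and connection are assumed from a store's constructor): advance() now takes the writer's instance name, which single-writer trackers ignore, so the per-writer getter can simply defer to get_current_token():

    tracker = SlavedIdTracker(db_conn, "pushers", "id")  # db_conn assumed available
    tracker.advance("master", 42)  # instance_name is accepted but unused here
    assert tracker.get_current_token() == tracker.get_current_token_for_writer("master")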
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 9db6c62bc7..bb66ba9b80 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,13 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
-from synapse.storage.data_stores.main.tags import TagsWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.account_data import AccountDataWorkerStore
+from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"account_data",
@@ -39,13 +40,13 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "tag_account_data":
- self._account_data_id_gen.advance(token)
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
- elif stream_name == "account_data":
- self._account_data_id_gen.advance(token)
+ elif stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index a67fbeffb7..0f8d7037bd 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.appservice import (
+from synapse.storage.databases.main.appservice import (
ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore,
)
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 1a38f53dfb..a6fdedde63 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache
from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000
)
- def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+ async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
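With insert_client_ip now a coroutine, worker-side call sites have to await it; an illustrative call:

    await store.insert_client_ip(user_id, access_token, ip, user_agent, device_id)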
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 6e7fd259d4..533d927701 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,17 +15,18 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import ToDeviceStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_max_stream_id", "stream_id"
+ db_conn, "device_inbox", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
@@ -44,8 +45,8 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "to_device":
- self._device_inbox_id_gen.advance(token)
+ if stream_name == ToDeviceStream.NAME:
+ self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 9d8067342f..3b788c9625 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -16,14 +16,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
-from synapse.storage.data_stores.main.devices import DeviceWorkerStore
-from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.devices import DeviceWorkerStore
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
self.hs = hs
@@ -48,12 +48,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
+ def get_device_stream_token(self) -> int:
+ return self._device_list_id_gen.get_current_token()
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
- self._device_list_id_gen.advance(token)
+ self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(token)
+ self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 8b9717c46f..1945bcf9a8 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
+from synapse.storage.databases.main.directory import DirectoryWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 1a1a50a24f..da1cc836cf 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,18 +15,18 @@
# limitations under the License.
import logging
-from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
-from synapse.storage.data_stores.main.event_push_actions import (
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
+from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.relations import RelationsWorkerStore
-from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
-from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
-from synapse.storage.data_stores.main.stream import StreamWorkerStore
-from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
-from synapse.storage.database import Database
+from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.relations import RelationsWorkerStore
+from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.storage.databases.main.state import StateGroupWorkerStore
+from synapse.storage.databases.main.stream import StreamWorkerStore
+from synapse.storage.databases.main.user_erasure_store import UserErasureWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -55,11 +55,11 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedEventStore, self).__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index bcb0688954..2562b6fc38 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.filtering import FilteringStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.filtering import FilteringStore
from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 1851e7d525..567b4a5cc1 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -15,13 +15,14 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import GroupServerStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs
@@ -38,8 +39,8 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "groups":
- self._group_updates_id_gen.advance(token)
+ if stream_name == GroupServerStream.NAME:
+ self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index 3def367ae9..961579751c 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.keys import KeyStore
+from synapse.storage.databases.main.keys import KeyStore
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that
# the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 4e0124842d..025f6f6be8 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore
-from synapse.storage.data_stores.main.presence import PresenceStore
-from synapse.storage.database import Database
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.presence import PresenceStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -23,7 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
@@ -42,8 +43,8 @@ class SlavedPresenceStore(BaseSlavedStore):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "presence":
- self._presence_id_gen.advance(token)
+ if stream_name == PresenceStream.NAME:
+ self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 28c508aad3..f85b20a071 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.data_stores.main.profile import ProfileWorkerStore
+from synapse.storage.databases.main.profile import ProfileWorkerStore
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6adb19463a..de904c943c 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,24 +14,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import PushRulesStream
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def get_push_rules_stream_token(self):
- return (
- self._push_rules_stream_id_gen.get_current_token(),
- self._stream_id_gen.get_current_token(),
- )
-
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "push_rules":
- self._push_rules_stream_id_gen.advance(token)
+ # We assert this for the benefit of mypy
+ assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
+
+ if stream_name == PushRulesStream.NAME:
+ self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cb78b49acb..9da218bfe8 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,15 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.pusher import PusherWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import PushersStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.pusher import PusherWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "pushers":
- self._pushers_id_gen.advance(token)
+ if stream_name == PushersStream.NAME:
+ self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index be716cc558..5c2986e050 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,23 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import ReceiptsStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
-# So, um, we want to borrow a load of functions intended for reading from
-# a DataStore, but we don't want to take functions that either write to the
-# DataStore or are cached and don't have cache invalidation logic.
-#
-# Rather than write duplicate versions of those functions, or lift them to
-# a common base class, we going to grab the underlying __func__ object from
-# the method descriptor on the DataStore and chuck them into our class.
-
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
@@ -52,8 +45,8 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "receipts":
- self._receipts_id_gen.advance(token)
+ if stream_name == ReceiptsStream.NAME:
+ self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 4b8553e250..a40f064e2b 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 8873bf37e5..80ae803ad9 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.room import RoomWorkerStore
-from synapse.storage.database import Database
+from synapse.replication.tcp.streams import PublicRoomsStream
+from synapse.storage.database import DatabasePool
+from synapse.storage.databases.main.room import RoomWorkerStore
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore):
- def __init__(self, database: Database, db_conn, hs):
+ def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
- if stream_name == "public_rooms":
- self._public_room_id_gen.advance(token)
+ if stream_name == PublicRoomsStream.NAME:
+ self._public_room_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index ac88e6b8c3..2091ac0df6 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.data_stores.main.transactions import TransactionStore
+from synapse.storage.databases.main.transactions import TransactionStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 523a1358d4..1b8718b11d 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -25,7 +25,7 @@ Structure of the module:
* command.py - the definitions of all the valid commands
* protocol.py - the TCP protocol classes
* resource.py - handles streaming stream updates to replications
- * streams/ - the definitons of all the valid streams
+ * streams/ - the definitions of all the valid streams
The general interaction of the classes are:
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index df29732f51..d6ecf5b327 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
-import heapq
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple
@@ -24,6 +23,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams import TypingStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
@@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
- from synapse.server import HomeServer
from synapse.replication.tcp.handler import ReplicationCommandHandler
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -104,6 +104,7 @@ class ReplicationDataHandler:
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
+ self._typing_handler = hs.get_typing_handler()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
@@ -127,6 +128,12 @@ class ReplicationDataHandler:
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
+ if stream_name == TypingStream.NAME:
+ self._typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows]
+ )
+
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
@@ -211,9 +218,8 @@ class ReplicationDataHandler:
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
- # We insert into the list using heapq as it is more efficient than
- # pushing then resorting each time.
- heapq.heappush(waiting_list, (position, deferred))
+ waiting_list.append((position, deferred))
+ waiting_list.sort(key=lambda t: t[0])
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
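The heapq insert is replaced by append-then-sort on a plain list, which is cheap in CPython because Timsort is near-linear on already-sorted input. A minimal sketch of the resulting bookkeeping (the draining side is assumed; it is not part of this hunk):

    waiting_list = []  # [(position, deferred), ...] kept sorted by position
    waiting_list.append((position, deferred))
    waiting_list.sort(key=lambda t: t[0])

    # assumed consumer: fire every waiter whose position has been reached
    index = 0
    while index < len(waiting_list) and waiting_list[index][0] <= token:
        waiting_list[index][1].callback(None)
        index += 1
    del waiting_list[:index]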
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index c04f622816..8cd47770c1 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,17 +19,9 @@ allowed to be sent by which side.
"""
import abc
import logging
-import platform
from typing import Tuple, Type
-if platform.python_implementation() == "PyPy":
- import json
-
- _json_encoder = json.JSONEncoder()
-else:
- import simplejson as json # type: ignore[no-redef] # noqa: F821
-
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
+from synapse.util import json_decoder, json_encoder
logger = logging.getLogger(__name__)
@@ -54,7 +46,7 @@ class Command(metaclass=abc.ABCMeta):
@abc.abstractmethod
def to_line(self) -> str:
- """Serialises the comamnd for the wire. Does not include the command
+ """Serialises the command for the wire. Does not include the command
prefix.
"""
@@ -131,7 +123,7 @@ class RdataCommand(Command):
stream_name,
instance_name,
None if token == "batch" else int(token),
- json.loads(row_json),
+ json_decoder.decode(row_json),
)
def to_line(self):
@@ -140,7 +132,7 @@ class RdataCommand(Command):
self.stream_name,
self.instance_name,
str(self.token) if self.token is not None else "batch",
- _json_encoder.encode(self.row),
+ json_encoder.encode(self.row),
)
)
@@ -149,7 +141,7 @@ class RdataCommand(Command):
class PositionCommand(Command):
- """Sent by the server to tell the client the stream postition without
+ """Sent by the server to tell the client the stream position without
needing to send an RDATA.
Format::
@@ -188,7 +180,7 @@ class ErrorCommand(_SimpleCommand):
class PingCommand(_SimpleCommand):
- """Sent by either side as a keep alive. The data is arbitary (often timestamp)
+ """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
"""
NAME = "PING"
@@ -300,20 +292,22 @@ class FederationAckCommand(Command):
Format::
- FEDERATION_ACK <token>
+ FEDERATION_ACK <instance_name> <token>
"""
NAME = "FEDERATION_ACK"
- def __init__(self, token):
+ def __init__(self, instance_name, token):
+ self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
- return cls(int(line))
+ instance_name, token = line.split(" ")
+ return cls(instance_name, int(token))
def to_line(self):
- return str(self.token)
+ return "%s %s" % (self.instance_name, self.token)
class RemovePusherCommand(Command):
@@ -363,7 +357,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
- access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
+ access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
@@ -371,7 +365,7 @@ class UserIpCommand(Command):
return (
self.user_id
+ " "
- + _json_encoder.encode(
+ + json_encoder.encode(
(
self.access_token,
self.ip,
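A round-trip of the extended FEDERATION_ACK wire format, which now carries the acking instance's name ahead of the token:

    cmd = FederationAckCommand("worker1", 1234)
    assert cmd.to_line() == "worker1 1234"

    parsed = FederationAckCommand.from_line("worker1 1234")
    assert (parsed.instance_name, parsed.token) == ("worker1", 1234)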
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index cbcf46f3ae..1c303f3a46 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -13,15 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+from typing import (
+ Any,
+ Awaitable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
from prometheus_client import Counter
+from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
@@ -43,8 +56,8 @@ from synapse.replication.tcp.streams import (
EventsStream,
FederationStream,
Stream,
+ TypingStream,
)
-from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -56,12 +69,16 @@ inbound_rdata_count = Counter(
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
- "synapse_replication_tcp_resource_invalidate_cache", ""
-)
+
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+# the type of the entries in _command_queues_by_stream
+_StreamCommandQueue = Deque[
+ Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
+]
+
+
class ReplicationCommandHandler:
"""Handles incoming commands from replication as well as sending commands
back out to connections.
@@ -97,6 +114,14 @@ class ReplicationCommandHandler:
continue
+ if isinstance(stream, TypingStream):
+ # Only add TypingStream as a source on the instance in charge of
+ # typing.
+ if hs.config.worker.writers.typing == hs.get_instance_name():
+ self._streams_to_replicate.append(stream)
+
+ continue
+
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue
@@ -108,12 +133,8 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
- self._position_linearizer = Linearizer(
- "replication_position", clock=self._clock
- )
-
- # Map of stream to batched updates. See RdataCommand for info on how
- # batching works.
+ # Map of stream name to batched updates. See RdataCommand for info on
+ # how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections.
@@ -123,9 +144,6 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
- # For each connection, the incoming streams that are coming from that connection
- self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
-
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@@ -133,6 +151,32 @@ class ReplicationCommandHandler:
lambda: len(self._connections),
)
+ # When POSITION or RDATA commands arrive, we stick them in a queue and process
+ # them in order in a separate background process.
+
+ # the streams which are currently being processed by _unsafe_process_queue
+ self._processing_streams = set() # type: Set[str]
+
+ # for each stream, a queue of commands that are awaiting processing, and the
+ # connection that they arrived on.
+ self._command_queues_by_stream = {
+ stream_name: _StreamCommandQueue() for stream_name in self._streams
+ }
+
+ # For each connection, the incoming stream names that have received a POSITION
+ # from that connection.
+ self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
+
+ LaterGauge(
+ "synapse_replication_tcp_command_queue",
+ "Number of inbound RDATA/POSITION commands queued for processing",
+ ["stream_name"],
+ lambda: {
+ (stream_name,): len(queue)
+ for stream_name, queue in self._command_queues_by_stream.items()
+ },
+ )
+
self._is_master = hs.config.worker_app is None
self._federation_sender = None
@@ -143,15 +187,75 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
+ def _add_command_to_stream_queue(
+ self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
+ ) -> None:
+ """Queue the given received command for processing
+
+ Adds the given command to the per-stream queue, and processes the queue if
+ necessary
+ """
+ stream_name = cmd.stream_name
+ queue = self._command_queues_by_stream.get(stream_name)
+ if queue is None:
+ logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
+ return
+
+ queue.append((cmd, conn))
+
+ # if we're already processing this stream, there's nothing more to do:
+ # the new entry on the queue will get picked up in due course
+ if stream_name in self._processing_streams:
+ return
+
+ # fire off a background process to start processing the queue.
+ run_as_background_process(
+ "process-replication-data", self._unsafe_process_queue, stream_name
+ )
+
+ async def _unsafe_process_queue(self, stream_name: str):
+ """Processes the command queue for the given stream, until it is empty
+
+ Does not check if there is already a thread processing the queue, hence "unsafe"
+ """
+ assert stream_name not in self._processing_streams
+
+ self._processing_streams.add(stream_name)
+ try:
+ queue = self._command_queues_by_stream.get(stream_name)
+ while queue:
+ cmd, conn = queue.popleft()
+ try:
+ await self._process_command(cmd, conn, stream_name)
+ except Exception:
+ logger.exception("Failed to handle command %s", cmd)
+ finally:
+ self._processing_streams.discard(stream_name)
+
+ async def _process_command(
+ self,
+ cmd: Union[PositionCommand, RdataCommand],
+ conn: AbstractConnection,
+ stream_name: str,
+ ) -> None:
+ if isinstance(cmd, PositionCommand):
+ await self._process_position(stream_name, conn, cmd)
+ elif isinstance(cmd, RdataCommand):
+ await self._process_rdata(stream_name, conn, cmd)
+ else:
+ # This shouldn't be possible
+ raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
+
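A note on the queue discipline implemented by `_add_command_to_stream_queue` and `_unsafe_process_queue`: each stream gets its own deque, and at most one consumer drains a given deque at a time, so commands for a stream are handled strictly in arrival order while distinct streams make progress concurrently. Below is a minimal standalone sketch of the same pattern, using plain asyncio in place of `run_as_background_process`; all names are illustrative, not Synapse APIs.

```python
import asyncio
from collections import deque
from typing import Deque, Dict, Set


class PerStreamQueues:
    """Toy model of the per-stream queues: ordered within a stream,
    concurrent across streams."""

    def __init__(self, stream_names):
        self._queues: Dict[str, Deque[str]] = {n: deque() for n in stream_names}
        self._processing: Set[str] = set()

    def add_command(self, stream_name: str, cmd: str) -> None:
        queue = self._queues.get(stream_name)
        if queue is None:
            return  # unknown stream: the real code logs an error and drops it
        queue.append(cmd)
        if stream_name in self._processing:
            # an existing consumer will pick the new entry up in due course
            return
        # fire off a consumer task for this stream
        asyncio.ensure_future(self._drain(stream_name))

    async def _drain(self, stream_name: str) -> None:
        assert stream_name not in self._processing
        self._processing.add(stream_name)
        try:
            queue = self._queues[stream_name]
            while queue:
                cmd = queue.popleft()
                await self._handle(stream_name, cmd)  # real code logs errors here
        finally:
            self._processing.discard(stream_name)

    async def _handle(self, stream_name: str, cmd: str) -> None:
        print("processing", cmd, "on", stream_name)


async def main():
    q = PerStreamQueues(["events", "typing"])
    q.add_command("events", "RDATA 1")
    q.add_command("events", "RDATA 2")  # queued behind RDATA 1
    q.add_command("typing", "RDATA 7")  # drained by a separate task
    await asyncio.sleep(0)  # let the consumer tasks run


asyncio.run(main())
```

Note that the real `_unsafe_process_queue` wraps each command in try/except, so one bad command cannot wedge the rest of that stream's queue.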
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
if hs.config.redis.redis_enabled:
+ import txredisapi
+
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
)
- import txredisapi
logger.info(
"Connecting to redis (host=%r port=%r)",
@@ -198,7 +302,7 @@ class ReplicationCommandHandler:
"""
return self._streams_to_replicate
- async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+ def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
@@ -217,57 +321,73 @@ class ReplicationCommandHandler:
)
)
- async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+ def on_USER_SYNC(
+ self, conn: AbstractConnection, cmd: UserSyncCommand
+ ) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
if self._is_master:
- await self._presence_handler.update_external_syncs_row(
+ return self._presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
+ else:
+ return None
- async def on_CLEAR_USER_SYNC(
+ def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
- ):
+ ) -> Optional[Awaitable[None]]:
if self._is_master:
- await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+ return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+ else:
+ return None
- async def on_FEDERATION_ACK(
- self, conn: AbstractConnection, cmd: FederationAckCommand
- ):
+ def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self._federation_sender:
- self._federation_sender.federation_ack(cmd.token)
+ self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
- async def on_REMOVE_PUSHER(
+ def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
- ):
+ ) -> Optional[Awaitable[None]]:
remove_pusher_counter.inc()
if self._is_master:
- await self._store.delete_pusher_by_app_id_pushkey_user_id(
- app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
- )
+ return self._handle_remove_pusher(cmd)
+ else:
+ return None
+
+ async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+ await self._store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+ )
- self._notifier.on_new_replication_data()
+ self._notifier.on_new_replication_data()
- async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+ def on_USER_IP(
+ self, conn: AbstractConnection, cmd: UserIpCommand
+ ) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
if self._is_master:
- await self._store.insert_client_ip(
- cmd.user_id,
- cmd.access_token,
- cmd.ip,
- cmd.user_agent,
- cmd.device_id,
- cmd.last_seen,
- )
+ return self._handle_user_ip(cmd)
+ else:
+ return None
+
+ async def _handle_user_ip(self, cmd: UserIpCommand):
+ await self._store.insert_client_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
+ cmd.last_seen,
+ )
- if self._server_notices_sender:
- await self._server_notices_sender.on_user_ip(cmd.user_id)
+ assert self._server_notices_sender is not None
+ await self._server_notices_sender.on_user_ip(cmd.user_id)
- async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+ def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -275,42 +395,71 @@ class ReplicationCommandHandler:
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
- try:
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
- except Exception:
- logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
- raise
-
- # We linearize here for two reasons:
+ # We put the received command into a queue here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
- with await self._position_linearizer.queue(cmd.stream_name):
- # make sure that we've processed a POSITION for this stream *on this
- # connection*. (A POSITION on another connection is no good, as there
- # is no guarantee that we have seen all the intermediate updates.)
- sbc = self._streams_by_connection.get(conn)
- if not sbc or stream_name not in sbc:
- # Let's drop the row for now, on the assumption we'll receive a
- # `POSITION` soon and we'll catch up correctly then.
- logger.debug(
- "Discarding RDATA for unconnected stream %s -> %s",
- stream_name,
- cmd.token,
- )
- return
-
- if cmd.token is None:
- # I.e. this is part of a batch of updates for this stream (in
- # which case batch until we get an update for the stream with a non
- # None token).
- self._pending_batches.setdefault(stream_name, []).append(row)
- else:
- # Check if this is the last of a batch of updates
- rows = self._pending_batches.pop(stream_name, [])
- rows.append(row)
- await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+ self._add_command_to_stream_queue(conn, cmd)
+
+ async def _process_rdata(
+ self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
+ ) -> None:
+ """Process an RDATA command
+
+ Called after the command has been popped off the queue of inbound commands
+ """
+ try:
+ row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+ except Exception as e:
+ raise Exception(
+ "Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
+ ) from e
+
+ # make sure that we've processed a POSITION for this stream *on this
+ # connection*. (A POSITION on another connection is no good, as there
+ # is no guarantee that we have seen all the intermediate updates.)
+ sbc = self._streams_by_connection.get(conn)
+ if not sbc or stream_name not in sbc:
+ # Let's drop the row for now, on the assumption we'll receive a
+ # `POSITION` soon and we'll catch up correctly then.
+ logger.debug(
+ "Discarding RDATA for unconnected stream %s -> %s",
+ stream_name,
+ cmd.token,
+ )
+ return
+
+ if cmd.token is None:
+ # I.e. this is part of a batch of updates for this stream (in
+ # which case batch until we get an update for the stream with a non
+ # None token).
+ self._pending_batches.setdefault(stream_name, []).append(row)
+ return
+
+ # Check if this is the last of a batch of updates
+ rows = self._pending_batches.pop(stream_name, [])
+ rows.append(row)
+
+ stream = self._streams[stream_name]
+
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
+
+ # Discard this data if this token is earlier than the current
+ # position. Note that streams can be reset (in which case you
+ # expect an earlier token), but that must be preceded by a
+ # POSITION command.
+ if cmd.token <= current_token:
+ logger.debug(
+ "Discarding RDATA from stream %s at position %s before previous position %s",
+ stream_name,
+ cmd.token,
+ current_token,
+ )
+ else:
+ await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
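To make the discard condition at the end of `_process_rdata` concrete: if this worker last saw token 10 from the sending instance, RDATA at token 9 or 10 is stale (typically rows redelivered after a reconnect) and is dropped, while token 11 is new. A tiny worked example:

```python
# Worked example of the token check above (values are illustrative).
current_token = 10  # where this worker believes the sender's stream is

for rdata_token in (9, 10, 11):
    if rdata_token <= current_token:
        print("discard RDATA at", rdata_token)  # already seen / stale
    else:
        print("process RDATA at", rdata_token)  # genuinely new row

# Only token 11 is processed. A stream that has legitimately been reset
# to an earlier token must announce that with a POSITION command first,
# which is what makes this comparison safe.
```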
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -329,78 +478,74 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
- async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+ def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
- stream_name = cmd.stream_name
- stream = self._streams.get(stream_name)
- if not stream:
- logger.error("Got POSITION for unknown stream: %s", stream_name)
- return
+ self._add_command_to_stream_queue(conn, cmd)
- # We protect catching up with a linearizer in case the replication
- # connection reconnects under us.
- with await self._position_linearizer.queue(stream_name):
- # We're about to go and catch up with the stream, so remove from set
- # of connected streams.
- for streams in self._streams_by_connection.values():
- streams.discard(stream_name)
-
- # We clear the pending batches for the stream as the fetching of the
- # missing updates below will fetch all rows in the batch.
- self._pending_batches.pop(stream_name, [])
-
- # Find where we previously streamed up to.
- current_token = stream.current_token(cmd.instance_name)
-
- # If the position token matches our current token then we're up to
- # date and there's nothing to do. Otherwise, fetch all updates
- # between then and now.
- missing_updates = cmd.token != current_token
- while missing_updates:
- logger.info(
- "Fetching replication rows for '%s' between %i and %i",
- stream_name,
- current_token,
- cmd.token,
- )
- (
- updates,
- current_token,
- missing_updates,
- ) = await stream.get_updates_since(
- cmd.instance_name, current_token, cmd.token
- )
+ async def _process_position(
+ self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
+ ) -> None:
+ """Process a POSITION command
- # TODO: add some tests for this
+ Called after the command has been popped off the queue of inbound commands
+ """
+ stream = self._streams[stream_name]
- # Some streams return multiple rows with the same stream IDs,
- # which need to be processed in batches.
+ # We're about to go and catch up with the stream, so remove from set
+ # of connected streams.
+ for streams in self._streams_by_connection.values():
+ streams.discard(stream_name)
- for token, rows in _batch_updates(updates):
- await self.on_rdata(
- stream_name,
- cmd.instance_name,
- token,
- [stream.parse_row(row) for row in rows],
- )
+ # We clear the pending batches for the stream as the fetching of the
+ # missing updates below will fetch all rows in the batch.
+ self._pending_batches.pop(stream_name, [])
- logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+ # Find where we previously streamed up to.
+ current_token = stream.current_token(cmd.instance_name)
- # We've now caught up to position sent to us, notify handler.
- await self._replication_data_handler.on_position(
- cmd.stream_name, cmd.instance_name, cmd.token
+ # If the position token matches our current token then we're up to
+ # date and there's nothing to do. Otherwise, fetch all updates
+ # between then and now.
+ missing_updates = cmd.token != current_token
+ while missing_updates:
+ logger.info(
+ "Fetching replication rows for '%s' between %i and %i",
+ stream_name,
+ current_token,
+ cmd.token,
+ )
+ (updates, current_token, missing_updates) = await stream.get_updates_since(
+ cmd.instance_name, current_token, cmd.token
)
- self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+ # TODO: add some tests for this
- async def on_REMOTE_SERVER_UP(
- self, conn: AbstractConnection, cmd: RemoteServerUpCommand
- ):
+ # Some streams return multiple rows with the same stream IDs,
+ # which need to be processed in batches.
+
+ for token, rows in _batch_updates(updates):
+ await self.on_rdata(
+ stream_name,
+ cmd.instance_name,
+ token,
+ [stream.parse_row(row) for row in rows],
+ )
+
+ logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+ # We've now caught up to position sent to us, notify handler.
+ await self._replication_data_handler.on_position(
+ cmd.stream_name, cmd.instance_name, cmd.token
+ )
+
+ self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+
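The catch-up loop in `_process_position` leans on the `update_function` contract from `streams/_base.py` (`StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]`): each call returns the rows fetched, the token actually reached, and a `limited` flag saying whether more remain. Below is a self-contained sketch of a conforming fetcher and the driving loop; the row source is faked, and the real function also takes an `instance_name` argument first.

```python
import asyncio
from typing import List, Tuple

# Fake row source of (token, row) pairs; the real one is a database query.
ROWS = [(t, ("row-%d" % t,)) for t in range(1, 21)]


async def get_updates_since(
    from_token: int, upto_token: int, limit: int = 5
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    matched = [(t, r) for t, r in ROWS if from_token < t <= upto_token]
    updates = matched[:limit]
    if len(matched) > limit:
        # stopped short: report how far we actually got and that
        # another round trip is needed
        return updates, updates[-1][0], True
    return updates, upto_token, False


async def catch_up(current_token: int, target_token: int) -> int:
    missing_updates = current_token != target_token
    while missing_updates:
        updates, current_token, missing_updates = await get_updates_since(
            current_token, target_token
        )
        # ... hand `updates` to on_rdata-style processing here ...
    return current_token


assert asyncio.run(catch_up(0, 20)) == 20
```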
+ def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -505,7 +650,7 @@ class ReplicationCommandHandler:
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
- self.send_command(FederationAckCommand(token))
+ self.send_command(FederationAckCommand(self._instance_name, token))
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
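Related: `FederationAckCommand` now takes the acking instance's name as its first constructor argument, so the master can track federation stream positions per worker. The command definition itself lives in `commands.py`, outside this excerpt; the following is only a hedged sketch of what the instance-scoped command plausibly looks like, with the field order inferred from the call site above.

```python
# Hedged sketch of an instance-scoped FEDERATION_ACK command; the real
# definition is in synapse/replication/tcp/commands.py and may differ.
class FederationAckCommand:
    NAME = "FEDERATION_ACK"

    def __init__(self, instance_name: str, token: int):
        self.instance_name = instance_name
        self.token = token

    @classmethod
    def from_line(cls, line: str) -> "FederationAckCommand":
        instance_name, token = line.split(" ")
        return cls(instance_name, int(token))

    def to_line(self) -> str:
        return "%s %s" % (self.instance_name, self.token)


assert FederationAckCommand.from_line("worker1 42").to_line() == "worker1 42"
```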
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 4198eece71..0b0d204e64 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc
import fcntl
import logging
import struct
+from inspect import isawaitable
from typing import TYPE_CHECKING, List
from prometheus_client import Counter
@@ -57,8 +58,12 @@ from prometheus_client import Counter
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
+from synapse.logging.context import PreserveLoggingContext
from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+ BackgroundProcessLoggingContext,
+ run_as_background_process,
+)
from synapse.replication.tcp.commands import (
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
@@ -108,7 +113,7 @@ PING_TIMEOUT_MULTIPLIER = 5
PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
-class ConnectionStates(object):
+class ConnectionStates:
CONNECTING = "connecting"
ESTABLISHED = "established"
PAUSED = "paused"
@@ -124,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+ `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
+ if so, that will get run as a background process.
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
@@ -160,6 +167,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
+ # a logcontext which we use for processing incoming commands. We declare it as a
+ # background process so that the CPU stats get reported to prometheus.
+ ctx_name = "replication-conn-%s" % self.conn_id
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+ self._logging_context.request = ctx_name
+
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -210,6 +223,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def lineReceived(self, line: bytes):
"""Called when we've received a line
"""
+ with PreserveLoggingContext(self._logging_context):
+ self._parse_and_dispatch_line(line)
+
+ def _parse_and_dispatch_line(self, line: bytes):
if line.strip() == "":
# Ignore blank lines
return
@@ -232,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
First calls `self.on_<COMMAND>` if it exists, then calls
- `self.command_handler.on_<COMMAND>` if it exists. This allows for
- protocol level handling of commands (e.g. PINGs), before delegating to
- the handler.
+ `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+ return an Awaitable).
+
+ This allows for protocol level handling of commands (e.g. PINGs), before
+ delegating to the handler.
Args:
cmd: received command
@@ -254,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(cmd)
+ cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
- await cmd_func(self, cmd)
+ res = cmd_func(self, cmd)
+
+ # the handler may return a coroutine: if so, fire it off as a background
+ # process.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
+
handled = True
if not handled:
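One subtlety in the dispatch above: by the time we know `res` is awaitable, the handler has already been called, so `run_as_background_process` is handed a `lambda` that closes over the existing coroutine rather than a reference that would invoke the handler again. A toy illustration of the same shape, with an asyncio stand-in for the background-process helper:

```python
import asyncio
from typing import Awaitable, Callable


def run_in_background(desc: str, func: Callable[[], Awaitable]) -> "asyncio.Task":
    # stand-in for run_as_background_process: it expects a *callable*
    # producing an awaitable, which it invokes and schedules
    return asyncio.ensure_future(func())


async def on_RDATA() -> None:
    print("handled RDATA")  # imagine real side effects here


async def main():
    res = on_RDATA()  # handler already invoked; res is a coroutine object
    # NOT run_in_background("...", on_RDATA): that would run the handler
    # a second time. A coroutine can only be awaited once, so we wrap
    # the one we already have.
    await run_in_background("replication-rdata", lambda: res)


asyncio.run(main())
```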
@@ -317,7 +342,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again.
"""
- logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
+ logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
if len(self.pending_commands) > self.max_line_buffer:
@@ -336,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- async def on_PING(self, line):
+ def on_PING(self, line):
self.received_ping = True
- async def on_ERROR(self, cmd):
+ def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -397,6 +422,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
self.transport.unregisterProducer()
+ # mark the logging context as finished
+ self._logging_context.__exit__(None, None, None)
+
def __str__(self):
addr = None
if self.transport:
@@ -431,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ServerCommand(self.server_name))
super().connectionMade()
- async def on_NAME(self, cmd):
+ def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
@@ -460,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Once we've connected subscribe to the necessary streams
self.replicate()
- async def on_SERVER(self, cmd):
+ def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
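A pattern worth noting across this file: one `BackgroundProcessLoggingContext` is created per connection, every reactor callback re-activates it via `PreserveLoggingContext`, and it is finished explicitly with `__exit__(None, None, None)` on disconnect, because no single `with` block can span the connection's lifetime. Below is a reduced sketch of that lifecycle, with a stub standing in for the Synapse context classes (the stub glosses over `PreserveLoggingContext`'s save/restore of the previous context):

```python
class LoggingContextStub:
    """Stand-in for BackgroundProcessLoggingContext: entering activates
    it, exiting deactivates it, and the final explicit __exit__ marks it
    finished so its accumulated CPU usage is reported."""

    def __init__(self, name: str):
        self.name = name

    def __enter__(self) -> "LoggingContextStub":
        print("activate", self.name)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        print("deactivate", self.name)


class Connection:
    def __init__(self, conn_id: str):
        # created once at connection setup; outlives any single callback
        self._logging_context = LoggingContextStub("replication-conn-%s" % conn_id)

    def lineReceived(self, line: bytes) -> None:
        # each reactor callback runs inside the same long-lived context,
        # so work across callbacks is attributed to one "process"
        with self._logging_context:
            print("dispatching", line)

    def connectionLost(self, reason=None) -> None:
        # no enclosing `with` spans the whole connection, so finish the
        # context by hand, as the real code does above
        self._logging_context.__exit__(None, None, None)


conn = Connection("42")
conn.lineReceived(b"PING 1234")
conn.connectionLost()
```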
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index e776b63183..f225e533de 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,12 +14,16 @@
# limitations under the License.
import logging
+from inspect import isawaitable
from typing import TYPE_CHECKING
import txredisapi
-from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import (
+ BackgroundProcessLoggingContext,
+ run_as_background_process,
+)
from synapse.replication.tcp.commands import (
Command,
ReplicateCommand,
@@ -66,6 +70,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # a logcontext which we use for processing incoming commands. We declare it as a
+ # background process so that the CPU stats get reported to prometheus.
+ self._logging_context = BackgroundProcessLoggingContext(
+ "replication_command_handler"
+ )
+
def connectionMade(self):
logger.info("Connected to redis")
super().connectionMade()
@@ -92,7 +105,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
+ with PreserveLoggingContext(self._logging_context):
+ self._parse_and_dispatch_message(message)
+ def _parse_and_dispatch_message(self, message: str):
if message.strip() == "":
# Ignore blank lines
return
@@ -109,42 +125,41 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
- # Now lets try and call on_<CMD_NAME> function
- run_as_background_process(
- "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
- )
+ self.handle_command(cmd)
- async def handle_command(self, cmd: Command):
+ def handle_command(self, cmd: Command) -> None:
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>, which should return an awaitable.
+ Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+ Awaitable).
Args:
cmd: received command
"""
- handled = False
-
- # First call any command handlers on this instance. These are for redis
- # specific handling.
- cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(cmd)
- handled = True
- # Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
- if cmd_func:
- await cmd_func(self, cmd)
- handled = True
-
- if not handled:
+ if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
+ return
+
+ res = cmd_func(self, cmd)
+
+ # the handler may return a coroutine: if so, fire it off as a background
+ # process.
+
+ if isawaitable(res):
+ run_as_background_process(
+ "replication-" + cmd.get_logcontext_id(), lambda: res
+ )
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self)
+ # mark the logging context as finished
+ self._logging_context.__exit__(None, None, None)
+
def send_command(self, cmd: Command):
"""Send a command if connection has been established.
@@ -177,7 +192,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
- send outbound commands (this is seperate to the redis connection
+ send outbound commands (this is separate to the redis connection
used to subscribe).
"""
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 41569305df..04d894fb3d 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -58,7 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
)
-class ReplicationStreamer(object):
+class ReplicationStreamer:
"""Handles replication connections.
This needs to be poked when new replication data may be available. When new
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4acefc8a96..682d47f402 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -79,7 +79,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
-class Stream(object):
+class Stream:
"""Base class for the streams.
Provides a `get_updates()` function that returns new updates since the last
@@ -198,26 +198,6 @@ def current_token_without_instance(
return lambda instance_name: current_token()
-def db_query_to_update_function(
- query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
-) -> UpdateFunction:
- """Wraps a db query function which returns a list of rows to make it
- suitable for use as an `update_function` for the Stream class
- """
-
- async def update_function(instance_name, from_token, upto_token, limit):
- rows = await query_function(from_token, upto_token, limit)
- updates = [(row[0], row[1:]) for row in rows]
- limited = False
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
-
- return update_function
-
-
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
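With `db_query_to_update_function` deleted, every store method passed directly as an `update_function` below must itself satisfy the contract: accept `(instance_name, from_token, upto_token, limit)` and return `(updates, upto_token, limited)`. Presumably the storage methods were updated accordingly elsewhere in this change; here is a sketch of what absorbing the deleted wrapper's logic into a store method looks like (`_fetch_rows` stands in for the real database query):

```python
import asyncio
from typing import List, Tuple


async def _fetch_rows(from_token: int, upto_token: int, limit: int) -> List[tuple]:
    # pretend database query: rows whose first element is the stream token
    return [(t, "data-%d" % t) for t in range(from_token + 1, upto_token + 1)][:limit]


async def get_all_example_rows(
    instance_name: str, from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    rows = await _fetch_rows(from_token, upto_token, limit)
    updates = [(row[0], row[1:]) for row in rows]
    limited = False
    if len(updates) >= limit:
        upto_token = updates[-1][0]  # we may have stopped short of upto_token
        limited = True
    return updates, upto_token, limited


updates, token, limited = asyncio.run(get_all_example_rows("master", 0, 10, 5))
assert token == 5 and limited  # caller will loop for the remaining rows
```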
@@ -264,7 +244,7 @@ class BackfillStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_backfill_token),
- db_query_to_update_function(store.get_all_new_backfill_event_rows),
+ store.get_all_new_backfill_event_rows,
)
@@ -291,9 +271,7 @@ class PresenceStream(Stream):
if hs.config.worker_app is None:
# on the master, query the presence handler
presence_handler = hs.get_presence_handler()
- update_function = db_query_to_update_function(
- presence_handler.get_all_presence_updates
- )
+ update_function = presence_handler.get_all_presence_updates
else:
# Query master process
update_function = make_http_update_function(hs, self.NAME)
@@ -316,13 +294,12 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- if hs.config.worker_app is None:
- # on the master, query the typing handler
- update_function = db_query_to_update_function(
- typing_handler.get_all_typing_updates
- )
+ writer_instance = hs.config.worker.writers.typing
+ if writer_instance == hs.get_instance_name():
+ # On the writer, query the typing handler
+ update_function = typing_handler.get_all_typing_updates
else:
- # Query master process
+ # Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
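The typing stream's update function is now served locally only on the configured typing writer; every other instance, including the master if it is not the writer, falls back to fetching over HTTP replication. A toy sketch of the routing decision (the config shape is an assumption based on `hs.config.worker.writers.typing`):

```python
class WorkerWriters:
    # assumed shape of the worker `writers` config; only `typing` shown
    def __init__(self, typing: str = "master"):
        self.typing = typing


def pick_update_function(instance_name: str, writers: WorkerWriters) -> str:
    if writers.typing == instance_name:
        return "local: typing_handler.get_all_typing_updates"
    return "remote: HTTP replication to %s" % writers.typing


assert pick_update_function("master", WorkerWriters()).startswith("local")
assert pick_update_function("worker1", WorkerWriters()).startswith("remote")
```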
@@ -352,7 +329,7 @@ class ReceiptsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_receipt_stream_id),
- db_query_to_update_function(store.get_all_updated_receipts),
+ store.get_all_updated_receipts,
)
@@ -367,26 +344,17 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
+
super(PushRulesStream, self).__init__(
- hs.get_instance_name(), self._current_token, self._update_function
+ hs.get_instance_name(),
+ self._current_token,
+ self.store.get_all_push_rule_updates,
)
def _current_token(self, instance_name: str) -> int:
- push_rules_token, _ = self.store.get_push_rules_stream_token()
+ push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token
- async def _update_function(
- self, instance_name: str, from_token: Token, to_token: Token, limit: int
- ):
- rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-
- limited = False
- if len(rows) == limit:
- to_token = rows[-1][0]
- limited = True
-
- return [(row[0], (row[2],)) for row in rows], to_token, limited
-
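Note that the deleted `_update_function` did more than forward the call: it applied the limit/`limited` bookkeeping and reshaped rows as `(stream_id, (user_id,))`. That logic must now live in `get_all_push_rule_updates` itself; the storage-layer change is not shown in this diff, so the following is only an assumed sketch of the shape it needs:

```python
from typing import List, Tuple


async def get_all_push_rule_updates(
    instance_name: str, from_token: int, to_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    # stand-in for the real DB query returning (stream_id, event_stream_ordering,
    # user_id, ...) rows; the actual column set is an assumption
    rows: List[tuple] = []
    limited = False
    if len(rows) == limit:
        to_token = rows[-1][0]
        limited = True
    # same row reshaping the deleted wrapper performed: (stream_id, (user_id,))
    return [(row[0], (row[2],)) for row in rows], to_token, limited
```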
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
@@ -406,7 +374,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
- db_query_to_update_function(store.get_all_updated_pushers_rows),
+ store.get_all_updated_pushers_rows,
)
@@ -434,27 +402,13 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
- self.store = hs.get_datastore()
+ store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
- self.store.get_cache_stream_token,
- self._update_function,
+ store.get_cache_stream_token_for_writer,
+ store.get_all_updated_caches,
)
- async def _update_function(
- self, instance_name: str, from_token: int, upto_token: int, limit: int
- ):
- rows = await self.store.get_all_updated_caches(
- instance_name, from_token, upto_token, limit
- )
- updates = [(row[0], row[1:]) for row in rows]
- limited = False
- if len(updates) >= limit:
- upto_token = updates[-1][0]
- limited = True
-
- return updates, upto_token, limited
-
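The switch to `get_cache_stream_token_for_writer` means cache-invalidation positions are tracked per writer instance rather than as one global counter, which is also why `get_all_updated_caches` can be used directly now that the standard update-function signature carries `instance_name`. A minimal sketch of per-writer token bookkeeping (illustrative, not the Synapse store):

```python
from typing import Dict


class CacheStreamTokens:
    def __init__(self):
        self._tokens: Dict[str, int] = {}

    def advance(self, instance_name: str, token: int) -> None:
        # positions only ever move forward for a given writer
        current = self._tokens.get(instance_name, 0)
        self._tokens[instance_name] = max(current, token)

    def get_token_for_writer(self, instance_name: str) -> int:
        # a writer we have never heard from is at position 0
        return self._tokens.get(instance_name, 0)


tokens = CacheStreamTokens()
tokens.advance("worker1", 5)
assert tokens.get_token_for_writer("worker1") == 5
assert tokens.get_token_for_writer("worker2") == 0
```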
class PublicRoomsStream(Stream):
"""The public rooms list changed
@@ -478,7 +432,7 @@ class PublicRoomsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_public_room_stream_id),
- db_query_to_update_function(store.get_all_new_public_rooms),
+ store.get_all_new_public_rooms,
)
@@ -499,7 +453,7 @@ class DeviceListsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
- db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+ store.get_all_device_list_changes_for_remotes,
)
@@ -517,7 +471,7 @@ class ToDeviceStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
- db_query_to_update_function(store.get_all_new_device_messages),
+ store.get_all_new_device_messages,
)
@@ -537,7 +491,7 @@ class TagAccountDataStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_account_data_stream_id),
- db_query_to_update_function(store.get_all_updated_tags),
+ store.get_all_updated_tags,
)
@@ -625,7 +579,7 @@ class GroupServerStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_group_stream_token),
- db_query_to_update_function(store.get_all_groups_changes),
+ store.get_all_groups_changes,
)
@@ -643,7 +597,5 @@ class UserSignatureStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
- db_query_to_update_function(
- store.get_all_user_signature_changes_for_remotes
- ),
+ store.get_all_user_signature_changes_for_remotes,
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f370390331..f929fc3954 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,16 +13,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import heapq
-from collections import Iterable
+from collections.abc import Iterable
from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
-
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -51,20 +49,20 @@ data part are:
@attr.s(slots=True, frozen=True)
-class EventsStreamRow(object):
+class EventsStreamRow:
"""A parsed row from the events replication stream"""
type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
data = attr.ib() # BaseEventsStreamRow
-class BaseEventsStreamRow(object):
+class BaseEventsStreamRow:
"""Base class for rows to be sent in the events stream.
Specifies how to identify, serialize and deserialize the different types.
"""
- # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overridden in sub classes.
TypeId = None # type: str
@classmethod
|