Diffstat (limited to 'synapse/replication')
-rw-r--r--   synapse/replication/http/__init__.py                    |   2
-rw-r--r--   synapse/replication/http/_base.py                       |  49
-rw-r--r--   synapse/replication/http/account_data.py                | 187
-rw-r--r--   synapse/replication/http/login.py                       |  12
-rw-r--r--   synapse/replication/slave/storage/_base.py              |  10
-rw-r--r--   synapse/replication/slave/storage/_slaved_id_tracker.py |  20
-rw-r--r--   synapse/replication/slave/storage/account_data.py       |  40
-rw-r--r--   synapse/replication/slave/storage/deviceinbox.py        |  40
-rw-r--r--   synapse/replication/slave/storage/pushers.py            |  17
-rw-r--r--   synapse/replication/slave/storage/receipts.py           |  35
-rw-r--r--   synapse/replication/tcp/external_cache.py               | 105
-rw-r--r--   synapse/replication/tcp/handler.py                      |  49
-rw-r--r--   synapse/replication/tcp/protocol.py                     |   3
-rw-r--r--   synapse/replication/tcp/redis.py                        | 143
14 files changed, 510 insertions, 202 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..dd527e807f 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,6 +15,7 @@
 
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb14..288727a566 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         assert self.METHOD in ("PUT", "POST", "GET")
 
+        self._replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            self._replication_secret = hs.config.worker.worker_replication_secret
+
+    def _check_auth(self, request) -> None:
+        # Get the authorization header.
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if len(auth_headers) > 1:
+            raise RuntimeError("Too many Authorization headers.")
+        parts = auth_headers[0].split(b" ")
+        if parts[0] == b"Bearer" and len(parts) == 2:
+            received_secret = parts[1].decode("ascii")
+            if self._replication_secret == received_secret:
+                # Success!
+                return
+
+        raise RuntimeError("Invalid Authorization header.")
+
     @abc.abstractmethod
     async def _serialize_payload(**kwargs):
         """Static method that is called when creating a request.
@@ -150,9 +169,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
 
+        replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            replication_secret = hs.config.worker.worker_replication_secret.encode(
+                "ascii"
+            )
+
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(instance_name="master", **kwargs):
+        async def send_request(*, instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
@@ -202,6 +227,9 @@
             # the master, and so whether we should clean up or not.
             while True:
                 headers = {}  # type: Dict[bytes, List[bytes]]
+                # Add an authorization header, if configured.
+                if replication_secret:
+                    headers[b"Authorization"] = [b"Bearer " + replication_secret]
                 inject_active_span_byte_dict(headers, None, check_destination=False)
                 try:
                     result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@
         """
         url_args = list(self.PATH_ARGS)
-        handler = self._handle_request
         method = self.METHOD
 
         if self.CACHE:
-            handler = self._cached_handler  # type: ignore
             url_args.append("txn_id")
 
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
         http_server.register_paths(
-            method, [pattern], handler, self.__class__.__name__,
+            method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
         )
 
-    def _cached_handler(self, request, txn_id, **kwargs):
+    def _check_auth_and_handle(self, request, **kwargs):
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
         otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@
         # We just use the txn_id here, but we probably also want to use the
         # other PATH_ARGS as well.
 
-        assert self.CACHE
+        # Check the authorization headers before handling the request.
+        if self._replication_secret:
+            self._check_auth(request)
+
+        if self.CACHE:
+            txn_id = kwargs.pop("txn_id")
+
+            return self.response_cache.wrap(
+                txn_id, self._handle_request, request, **kwargs
+            )
 
-        return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+        return self._handle_request(request, **kwargs)
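Note: with the `_base.py` changes above, every replication HTTP request must carry the shared `worker_replication_secret` as a bearer token once it is configured. A minimal standalone sketch of the round trip, using only the logic shown in the diff (the `SECRET` value is illustrative):

    SECRET = "some-shared-secret"  # stands in for worker_replication_secret

    def check_auth(auth_headers):
        # Mirrors ReplicationEndpoint._check_auth above: at most one
        # Authorization header, of the form b"Bearer <secret>", whose secret
        # matches the configured one.
        if len(auth_headers) > 1:
            raise RuntimeError("Too many Authorization headers.")
        parts = auth_headers[0].split(b" ")
        if parts[0] == b"Bearer" and len(parts) == 2:
            if parts[1].decode("ascii") == SECRET:
                return  # Success!
        raise RuntimeError("Invalid Authorization header.")

    # The sending side (send_request above) attaches the same secret:
    headers = {b"Authorization": [b"Bearer " + SECRET.encode("ascii")]}
    check_auth(headers[b"Authorization"])  # passes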
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..52d32528ee
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)
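Note: these servlets are not called directly from workers; a handler builds a client with `ReplicationEndpoint.make_client` (the factory whose `send_request` is defined in `_base.py` above) and awaits it. A sketch of that calling pattern — the handler class, the `[0]` writer selection, and the config shape are assumptions for illustration; the real caller lives outside this diff:

    from synapse.replication.http.account_data import (
        ReplicationUserAccountDataRestServlet,
    )

    class SketchAccountDataHandler:
        def __init__(self, hs):
            # make_client returns the send_request coroutine from _base.py;
            # note that instance_name is keyword-only after this change.
            self._add_user_data = ReplicationUserAccountDataRestServlet.make_client(hs)
            # Assumed: route to one of the configured account_data writers.
            self._writer = hs.config.worker.writers.account_data[0]

        async def add_account_data_for_user(self, user_id, account_data_type, content):
            response = await self._add_user_data(
                instance_name=self._writer,
                user_id=user_id,
                account_data_type=account_data_type,
                content=content,
            )
            return response["max_stream_id"]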
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 4c81e2d784..36071feb36 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
-    async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+    async def _serialize_payload(
+        user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+    ):
         """
         Args:
             device_id (str|None): Device ID to use, if None a new one is
@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "device_id": device_id,
             "initial_display_name": initial_display_name,
             "is_guest": is_guest,
+            "is_appservice_ghost": is_appservice_ghost,
         }
 
     async def _handle_request(self, request, user_id):
@@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         device_id = content["device_id"]
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]
+        is_appservice_ghost = content["is_appservice_ghost"]
 
         device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name, is_guest
+            user_id,
+            device_id,
+            initial_display_name,
+            is_guest,
+            is_appservice_ghost=is_appservice_ghost,
         )
 
         return 200, {"device_id": device_id, "access_token": access_token}
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Optional, Tuple
 
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn: Connection,
+        table: str,
+        column: str,
+        extra_tables: Optional[List[Tuple[str, str]]] = None,
+        step: int = 1,
+    ):
         self.step = step
         self._current = _load_current_id(db_conn, table, column, step)
-        for table, column in extra_tables:
-            self.advance(None, _load_current_id(db_conn, table, column))
+        if extra_tables:
+            for table, column in extra_tables:
+                self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name, new_id):
+    def advance(self, instance_name: Optional[str], new_id: int):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         """
         Returns:
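Note: the retyped `SlavedIdTracker` keeps the same semantics as before — it seeds its position from one or more tables and only ever moves forward. A small worked example of the `advance` rule, standalone with an in-memory current value rather than a real `db_conn`:

    # With step=1, (max if step > 0 else min) keeps the furthest-forward id
    # seen, so a stale or out-of-order replication update can never move the
    # token backwards.
    step = 1
    current = 5  # as if _load_current_id(db_conn, table, column) returned 5

    for new_id in (8, 3, 10):
        current = (max if step > 0 else min)(current, new_id)

    assert current == 10  # the update to 3 was ignored as it was behind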
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 5b045bed02..1260f6d141 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -14,46 +14,8 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
-        self._pushers_id_gen = SlavedIdTracker(
+        self._pushers_id_gen = SlavedIdTracker(  # type: ignore
            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token, rows
+    ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)
+            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+set_counter = Counter(
+    "synapse_external_cache_set",
+    "Number of times we set a cache",
+    labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+    "synapse_external_cache_get",
+    "Number of times we get a cache",
+    labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+    """A cache backed by an external Redis. Does nothing if no Redis is
+    configured.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._redis_connection = hs.get_outbound_redis_connection()
+
+    def _get_redis_key(self, cache_name: str, key: str) -> str:
+        return "cache_v1:%s:%s" % (cache_name, key)
+
+    def is_enabled(self) -> bool:
+        """Whether the external cache is used or not.
+
+        It's safe to use the cache when this returns false, the methods will
+        just no-op, but the function is useful to avoid doing unnecessary work.
+        """
+        return self._redis_connection is not None
+
+    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+        """Add the key/value to the named cache, with the expiry time given.
+        """
+
+        if self._redis_connection is None:
+            return
+
+        set_counter.labels(cache_name).inc()
+
+        # txredisapi requires the value to be string, bytes or numbers, so we
+        # encode stuff in JSON.
+        encoded_value = json_encoder.encode(value)
+
+        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.set(
+                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+            )
+        )
+
+    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+        """Look up a key/value in the named cache.
+        """
+
+        if self._redis_connection is None:
+            return None
+
+        result = await make_deferred_yieldable(
+            self._redis_connection.get(self._get_redis_key(cache_name, key))
+        )
+
+        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+        get_counter.labels(cache_name, result is not None).inc()
+
+        if not result:
+            return None
+
+        # For some reason the integers get magically converted back to integers
+        if isinstance(result, int):
+            return result
+
+        return json_decoder.decode(result)
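Note: a short usage sketch for the new `ExternalCache`; the cache name, key, value and expiry are illustrative only:

    async def compute_with_cache(hs):
        cache = ExternalCache(hs)

        # With no Redis configured, is_enabled() is False and get/set no-op.
        if cache.is_enabled():
            hit = await cache.get("my_cache", "some_key")
            if hit is not None:
                return hit

        value = {"computed": True}  # values are JSON-encoded internally
        # expiry_ms is passed to Redis as a millisecond expiry (pexpire).
        await cache.set("my_cache", "some_key", value, expiry_ms=30 * 60 * 1000)
        return value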
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 95e5502bf2..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -51,14 +52,21 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
@@ -84,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -115,6 +123,14 @@
 
                 continue
 
+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
@@ -123,6 +139,22 @@
 
                 continue
 
+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
@@ -254,13 +286,6 @@
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
             )
-
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
-            )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -271,13 +296,7 @@
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
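Note: `hs.get_outbound_redis_connection()` replaces the inline `lazyConnection` call so that the outbound command connection can be shared (the `ExternalCache` above uses the same getter). The getter itself lives outside `synapse/replication` and is not part of this diff; a plausible sketch of it, based on the reworked `lazyConnection` signature in `redis.py` below:

    # Hypothetical sketch of the HomeServer getter referenced above; the real
    # definition is not shown in this diff.
    def get_outbound_redis_connection(self):
        return lazyConnection(
            hs=self,
            host=self.config.redis_host,
            port=self.config.redis_port,
            password=self.config.redis.redis_password,
            reconnect=True,
        )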
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
         ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
-        self._logging_context.request = ctx_name
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
 
 import txredisapi
 
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import (
     BackgroundProcessLoggingContext,
     run_as_background_process,
+    wrap_as_background_process,
 )
 from synapse.replication.tcp.commands import (
     Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     immediately after initialisation.
 
     Attributes:
-        handler: The command handler to handle incoming commands.
-        stream_name: The *redis* stream name to subscribe to and publish from
-            (not anything to do with Synapse replication streams).
-        outbound_redis_connection: The connection to redis to use to send
+        synapse_handler: The command handler to handle incoming commands.
+        synapse_stream_name: The *redis* stream name to subscribe to and publish
+            from (not anything to do with Synapse replication streams).
+        synapse_outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
-    handler = None  # type: ReplicationCommandHandler
-    stream_name = None  # type: str
-    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler = None  # type: ReplicationCommandHandler
+    synapse_stream_name = None  # type: str
+    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
-        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-        self.handler.new_connection(self)
+        self.synapse_handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # otherside won't know we've connected and so won't issue a REPLICATE.
-        self.handler.send_positions_to_connection(self)
+        self.synapse_handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
@@ -137,7 +138,7 @@
             cmd: received command
         """
 
-        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
         if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
             return
@@ -155,7 +156,7 @@
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        self.handler.lost_connection(self)
+        self.synapse_handler.lost_connection(self)
 
         # mark the logging context as finished
         self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@
             tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
             await make_deferred_yieldable(
-                self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+                self.synapse_outbound_redis_connection.publish(
+                    self.synapse_stream_name, encoded_string
+                )
             )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+    """A subclass of RedisFactory that periodically sends pings to ensure that
+    we detect dead connections.
+    """
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = txredisapi.ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: int = 30,
+        convertNumbers: Optional[int] = True,
+    ):
+        super().__init__(
+            uuid=uuid,
+            dbid=dbid,
+            poolsize=poolsize,
+            isLazy=isLazy,
+            handler=handler,
+            charset=charset,
+            password=password,
+            replyTimeout=replyTimeout,
+            convertNumbers=convertNumbers,
+        )
+
+        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+    @wrap_as_background_process("redis_ping")
+    async def _send_ping(self):
+        for connection in self.pool:
+            try:
+                await make_deferred_yieldable(connection.ping())
+            except Exception:
+                logger.warning("Failed to send ping to a redis connection")
 
 
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.
 
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
     ):
 
-        super().__init__()
-
-        # This sets the password on the RedisFactory base class (as
-        # SubscriberFactory constructor doesn't pass it through).
-        self.password = hs.config.redis.redis_password
+        super().__init__(
+            hs,
+            uuid="subscriber",
+            dbid=None,
+            poolsize=1,
+            replyTimeout=30,
+            password=hs.config.redis.redis_password,
+        )
 
-        self.handler = hs.get_tcp_replication()
-        self.stream_name = hs.hostname
+        self.synapse_handler = hs.get_tcp_replication()
+        self.synapse_stream_name = hs.hostname
 
-        self.outbound_redis_connection = outbound_redis_connection
+        self.synapse_outbound_redis_connection = outbound_redis_connection
 
     def buildProtocol(self, addr):
-        p = super().buildProtocol(addr)  # type: RedisSubscriber
+        p = super().buildProtocol(addr)
+        p = cast(RedisSubscriber, p)
 
         # We do this here rather than add to the constructor of `RedisSubcriber`
         # as to do so would involve overriding `buildProtocol` entirely, however
         # the base method does some other things than just instantiating the
         # protocol.
-        p.handler = self.handler
-        p.outbound_redis_connection = self.outbound_redis_connection
-        p.stream_name = self.stream_name
-        p.password = self.password
+        p.synapse_handler = self.synapse_handler
+        p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+        p.synapse_stream_name = self.synapse_stream_name
 
         return p
 
 
 def lazyConnection(
-    reactor,
+    hs: "HomeServer",
     host: str = "localhost",
     port: int = 6379,
     dbid: Optional[int] = None,
     reconnect: bool = True,
-    charset: str = "utf-8",
     password: Optional[str] = None,
-    connectTimeout: Optional[int] = None,
-    replyTimeout: Optional[int] = None,
-    convertNumbers: bool = True,
+    replyTimeout: int = 30,
 ) -> txredisapi.RedisProtocol:
-    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
-    reactor.
+    """Creates a connection to Redis that is lazily set up and reconnects if the
+    connections is lost.
     """
 
-    isLazy = True
-    poolsize = 1
-
     uuid = "%s:%d" % (host, port)
-    factory = txredisapi.RedisFactory(
-        uuid,
-        dbid,
-        poolsize,
-        isLazy,
-        txredisapi.ConnectionHandler,
-        charset,
-        password,
-        replyTimeout,
-        convertNumbers,
+    factory = SynapseRedisFactory(
+        hs,
+        uuid=uuid,
+        dbid=dbid,
+        poolsize=1,
+        isLazy=True,
+        handler=txredisapi.ConnectionHandler,
+        password=password,
+        replyTimeout=replyTimeout,
    )
    factory.continueTrying = reconnect
-    for x in range(poolsize):
-        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    reactor = hs.get_reactor()
+    reactor.connectTCP(host, port, factory, 30)
 
     return factory.handler
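Note: the keepalive added by `SynapseRedisFactory` bounds how long a silently dead TCP connection (e.g. one dropped by a NAT) can go unnoticed to one ping interval. A condensed, standalone sketch of the same pattern using a bare Twisted LoopingCall rather than Synapse's `Clock.looping_call`/`wrap_as_background_process` shown above (`make_deferred_yieldable` and `logger` as in the module):

    from twisted.internet import defer, task

    def start_redis_keepalive(reactor, factory, interval=30.0):
        # Ping every pooled connection; txredisapi's ping() fails on a dead
        # connection, after which the factory's continueTrying flag drives
        # the reconnect.
        async def send_ping():
            for connection in factory.pool:
                try:
                    await make_deferred_yieldable(connection.ping())
                except Exception:
                    logger.warning("Failed to send ping to a redis connection")

        loop = task.LoopingCall(lambda: defer.ensureDeferred(send_ping()))
        loop.clock = reactor
        loop.start(interval, now=False)
        return loop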