summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/_base.py49
-rw-r--r--synapse/replication/http/account_data.py187
-rw-r--r--synapse/replication/http/login.py12
-rw-r--r--synapse/replication/slave/storage/_base.py10
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py20
-rw-r--r--synapse/replication/slave/storage/account_data.py40
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py40
-rw-r--r--synapse/replication/slave/storage/pushers.py17
-rw-r--r--synapse/replication/slave/storage/receipts.py35
-rw-r--r--synapse/replication/tcp/external_cache.py105
-rw-r--r--synapse/replication/tcp/handler.py49
-rw-r--r--synapse/replication/tcp/protocol.py3
-rw-r--r--synapse/replication/tcp/redis.py143
14 files changed, 510 insertions, 202 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index a84a064c8d..dd527e807f 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -15,6 +15,7 @@
 
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb14..288727a566 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         assert self.METHOD in ("PUT", "POST", "GET")
 
+        self._replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            self._replication_secret = hs.config.worker.worker_replication_secret
+
+    def _check_auth(self, request) -> None:
+        # Get the authorization header.
+        auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+        if len(auth_headers) > 1:
+            raise RuntimeError("Too many Authorization headers.")
+        parts = auth_headers[0].split(b" ")
+        if parts[0] == b"Bearer" and len(parts) == 2:
+            received_secret = parts[1].decode("ascii")
+            if self._replication_secret == received_secret:
+                # Success!
+                return
+
+        raise RuntimeError("Invalid Authorization header.")
+
     @abc.abstractmethod
     async def _serialize_payload(**kwargs):
         """Static method that is called when creating a request.
@@ -150,9 +169,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
 
+        replication_secret = None
+        if hs.config.worker.worker_replication_secret:
+            replication_secret = hs.config.worker.worker_replication_secret.encode(
+                "ascii"
+            )
+
         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(instance_name="master", **kwargs):
+        async def send_request(*, instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                 # the master, and so whether we should clean up or not.
                 while True:
                     headers = {}  # type: Dict[bytes, List[bytes]]
+                    # Add an authorization header, if configured.
+                    if replication_secret:
+                        headers[b"Authorization"] = [b"Bearer " + replication_secret]
                     inject_active_span_byte_dict(headers, None, check_destination=False)
                     try:
                         result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         """
 
         url_args = list(self.PATH_ARGS)
-        handler = self._handle_request
         method = self.METHOD
 
         if self.CACHE:
-            handler = self._cached_handler  # type: ignore
             url_args.append("txn_id")
 
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
         http_server.register_paths(
-            method, [pattern], handler, self.__class__.__name__,
+            method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
         )
 
-    def _cached_handler(self, request, txn_id, **kwargs):
+    def _check_auth_and_handle(self, request, **kwargs):
         """Called on new incoming requests when caching is enabled. Checks
         if there is a cached response for the request and returns that,
         otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         # We just use the txn_id here, but we probably also want to use the
         # other PATH_ARGS as well.
 
-        assert self.CACHE
+        # Check the authorization headers before handling the request.
+        if self._replication_secret:
+            self._check_auth(request)
+
+        if self.CACHE:
+            txn_id = kwargs.pop("txn_id")
+
+            return self.response_cache.wrap(
+                txn_id, self._handle_request, request, **kwargs
+            )
 
-        return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+        return self._handle_request(request, **kwargs)
diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py
new file mode 100644
index 0000000000..52d32528ee
--- /dev/null
+++ b/synapse/replication/http/account_data.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 4c81e2d784..36071feb36 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
-    async def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
+    async def _serialize_payload(
+        user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
+    ):
         """
         Args:
             device_id (str|None): Device ID to use, if None a new one is
@@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "device_id": device_id,
             "initial_display_name": initial_display_name,
             "is_guest": is_guest,
+            "is_appservice_ghost": is_appservice_ghost,
         }
 
     async def _handle_request(self, request, user_id):
@@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         device_id = content["device_id"]
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]
+        is_appservice_ghost = content["is_appservice_ghost"]
 
         device_id, access_token = await self.registration_handler.register_device(
-            user_id, device_id, initial_display_name, is_guest
+            user_id,
+            device_id,
+            initial_display_name,
+            is_guest,
+            is_appservice_ghost=is_appservice_ghost,
         )
 
         return 200, {"device_id": device_id, "access_token": access_token}
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index d0089fe06c..693c9ab901 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index eb74903d68..0d39a93ed2 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -12,21 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Optional, Tuple
 
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn: Connection,
+        table: str,
+        column: str,
+        extra_tables: Optional[List[Tuple[str, str]]] = None,
+        step: int = 1,
+    ):
         self.step = step
         self._current = _load_current_id(db_conn, table, column, step)
-        for table, column in extra_tables:
-            self.advance(None, _load_current_id(db_conn, table, column))
+        if extra_tables:
+            for table, column in extra_tables:
+                self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name, new_id):
+    def advance(self, instance_name: Optional[str], new_id: int):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         """
 
         Returns:
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 4268565fc8..21afe5f155 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -15,47 +15,9 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 5b045bed02..1260f6d141 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -14,46 +14,8 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import ToDeviceStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore
-from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_inbox", "stream_id"
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token(),
-        )
-
-        self._last_device_delete_cache = ExpiringCache(
-            cache_name="last_device_delete_cache",
-            clock=self._clock,
-            max_len=10000,
-            expiry_ms=30 * 60 * 1000,
-        )
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ToDeviceStream.NAME:
-            self._device_inbox_id_gen.advance(instance_name, token)
-            for row in rows:
-                if row.entity.startswith("@"):
-                    self._device_inbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-                else:
-                    self._device_federation_outbox_stream_cache.entity_has_changed(
-                        row.entity, token
-                    )
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index c418730ba8..045bd014da 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -13,26 +13,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
-        self._pushers_id_gen = SlavedIdTracker(
+        self._pushers_id_gen = SlavedIdTracker(  # type: ignore
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token, rows
+    ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)
+            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 6195917376..3dfdd9961d 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+set_counter = Counter(
+    "synapse_external_cache_set",
+    "Number of times we set a cache",
+    labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+    "synapse_external_cache_get",
+    "Number of times we get a cache",
+    labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+    """A cache backed by an external Redis. Does nothing if no Redis is
+    configured.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._redis_connection = hs.get_outbound_redis_connection()
+
+    def _get_redis_key(self, cache_name: str, key: str) -> str:
+        return "cache_v1:%s:%s" % (cache_name, key)
+
+    def is_enabled(self) -> bool:
+        """Whether the external cache is used or not.
+
+        It's safe to use the cache when this returns false, the methods will
+        just no-op, but the function is useful to avoid doing unnecessary work.
+        """
+        return self._redis_connection is not None
+
+    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+        """Add the key/value to the named cache, with the expiry time given.
+        """
+
+        if self._redis_connection is None:
+            return
+
+        set_counter.labels(cache_name).inc()
+
+        # txredisapi requires the value to be string, bytes or numbers, so we
+        # encode stuff in JSON.
+        encoded_value = json_encoder.encode(value)
+
+        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.set(
+                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+            )
+        )
+
+    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+        """Look up a key/value in the named cache.
+        """
+
+        if self._redis_connection is None:
+            return None
+
+        result = await make_deferred_yieldable(
+            self._redis_connection.get(self._get_redis_key(cache_name, key))
+        )
+
+        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+        get_counter.labels(cache_name, result is not None).inc()
+
+        if not result:
+            return None
+
+        # For some reason the integers get magically converted back to integers
+        if isinstance(result, int):
+            return result
+
+        return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 95e5502bf2..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -51,14 +52,21 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -84,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -115,6 +123,14 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, ToDeviceStream):
+                # Only add ToDeviceStream as a source on instances in charge of
+                # sending to device messages.
+                if hs.get_instance_name() in hs.config.worker.writers.to_device:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             if isinstance(stream, TypingStream):
                 # Only add TypingStream as a source on the instance in charge of
                 # typing.
@@ -123,6 +139,22 @@ class ReplicationCommandHandler:
 
                 continue
 
+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue
@@ -254,13 +286,6 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
-            )
-
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
             )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -271,13 +296,7 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
         ctx_name = "replication-conn-%s" % self.conn_id
-        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
-        self._logging_context.request = ctx_name
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
 
 import txredisapi
 
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import (
     BackgroundProcessLoggingContext,
     run_as_background_process,
+    wrap_as_background_process,
 )
 from synapse.replication.tcp.commands import (
     Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     immediately after initialisation.
 
     Attributes:
-        handler: The command handler to handle incoming commands.
-        stream_name: The *redis* stream name to subscribe to and publish from
-            (not anything to do with Synapse replication streams).
-        outbound_redis_connection: The connection to redis to use to send
+        synapse_handler: The command handler to handle incoming commands.
+        synapse_stream_name: The *redis* stream name to subscribe to and publish
+            from (not anything to do with Synapse replication streams).
+        synapse_outbound_redis_connection: The connection to redis to use to send
             commands.
     """
 
-    handler = None  # type: ReplicationCommandHandler
-    stream_name = None  # type: str
-    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+    synapse_handler = None  # type: ReplicationCommandHandler
+    synapse_stream_name = None  # type: str
+    synapse_outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
-        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
-        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-        self.handler.new_connection(self)
+        self.synapse_handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
         # otherside won't know we've connected and so won't issue a REPLICATE.
-        self.handler.send_positions_to_connection(self)
+        self.synapse_handler.send_positions_to_connection(self)
 
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             cmd: received command
         """
 
-        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
         if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
             return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        self.handler.lost_connection(self)
+        self.synapse_handler.lost_connection(self)
 
         # mark the logging context as finished
         self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
         await make_deferred_yieldable(
-            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+            self.synapse_outbound_redis_connection.publish(
+                self.synapse_stream_name, encoded_string
+            )
+        )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+    """A subclass of RedisFactory that periodically sends pings to ensure that
+    we detect dead connections.
+    """
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = txredisapi.ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: int = 30,
+        convertNumbers: Optional[int] = True,
+    ):
+        super().__init__(
+            uuid=uuid,
+            dbid=dbid,
+            poolsize=poolsize,
+            isLazy=isLazy,
+            handler=handler,
+            charset=charset,
+            password=password,
+            replyTimeout=replyTimeout,
+            convertNumbers=convertNumbers,
         )
 
+        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+    @wrap_as_background_process("redis_ping")
+    async def _send_ping(self):
+        for connection in self.pool:
+            try:
+                await make_deferred_yieldable(connection.ping())
+            except Exception:
+                logger.warning("Failed to send ping to a redis connection")
 
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.
 
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
     ):
 
-        super().__init__()
-
-        # This sets the password on the RedisFactory base class (as
-        # SubscriberFactory constructor doesn't pass it through).
-        self.password = hs.config.redis.redis_password
+        super().__init__(
+            hs,
+            uuid="subscriber",
+            dbid=None,
+            poolsize=1,
+            replyTimeout=30,
+            password=hs.config.redis.redis_password,
+        )
 
-        self.handler = hs.get_tcp_replication()
-        self.stream_name = hs.hostname
+        self.synapse_handler = hs.get_tcp_replication()
+        self.synapse_stream_name = hs.hostname
 
-        self.outbound_redis_connection = outbound_redis_connection
+        self.synapse_outbound_redis_connection = outbound_redis_connection
 
     def buildProtocol(self, addr):
-        p = super().buildProtocol(addr)  # type: RedisSubscriber
+        p = super().buildProtocol(addr)
+        p = cast(RedisSubscriber, p)
 
         # We do this here rather than add to the constructor of `RedisSubcriber`
         # as to do so would involve overriding `buildProtocol` entirely, however
         # the base method does some other things than just instantiating the
         # protocol.
-        p.handler = self.handler
-        p.outbound_redis_connection = self.outbound_redis_connection
-        p.stream_name = self.stream_name
-        p.password = self.password
+        p.synapse_handler = self.synapse_handler
+        p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+        p.synapse_stream_name = self.synapse_stream_name
 
         return p
 
 
 def lazyConnection(
-    reactor,
+    hs: "HomeServer",
     host: str = "localhost",
     port: int = 6379,
     dbid: Optional[int] = None,
     reconnect: bool = True,
-    charset: str = "utf-8",
     password: Optional[str] = None,
-    connectTimeout: Optional[int] = None,
-    replyTimeout: Optional[int] = None,
-    convertNumbers: bool = True,
+    replyTimeout: int = 30,
 ) -> txredisapi.RedisProtocol:
-    """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
-    reactor.
+    """Creates a connection to Redis that is lazily set up and reconnects if the
+    connections is lost.
     """
 
-    isLazy = True
-    poolsize = 1
-
     uuid = "%s:%d" % (host, port)
-    factory = txredisapi.RedisFactory(
-        uuid,
-        dbid,
-        poolsize,
-        isLazy,
-        txredisapi.ConnectionHandler,
-        charset,
-        password,
-        replyTimeout,
-        convertNumbers,
+    factory = SynapseRedisFactory(
+        hs,
+        uuid=uuid,
+        dbid=dbid,
+        poolsize=1,
+        isLazy=True,
+        handler=txredisapi.ConnectionHandler,
+        password=password,
+        replyTimeout=replyTimeout,
     )
     factory.continueTrying = reconnect
-    for x in range(poolsize):
-        reactor.connectTCP(host, port, factory, connectTimeout)
+
+    reactor = hs.get_reactor()
+    reactor.connectTCP(host, port, factory, 30)
 
     return factory.handler