summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9198.misc1
-rw-r--r--stubs/txredisapi.pyi12
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/federation/sender/__init__.py50
-rw-r--r--synapse/handlers/federation.py5
-rw-r--r--synapse/handlers/message.py42
-rw-r--r--synapse/replication/tcp/external_cache.py105
-rw-r--r--synapse/replication/tcp/handler.py15
-rw-r--r--synapse/server.py30
-rw-r--r--synapse/state/__init__.py11
-rw-r--r--tests/replication/_base.py41
11 files changed, 265 insertions, 49 deletions
diff --git a/changelog.d/9198.misc b/changelog.d/9198.misc
new file mode 100644
index 0000000000..a6cb77fbb2
--- /dev/null
+++ b/changelog.d/9198.misc
@@ -0,0 +1 @@
+Precompute joined hosts and store in Redis.
diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi
index bdc892ec82..618548a305 100644
--- a/stubs/txredisapi.pyi
+++ b/stubs/txredisapi.pyi
@@ -15,11 +15,21 @@
 
 """Contains *incomplete* type hints for txredisapi.
 """
-from typing import List, Optional, Type, Union
+from typing import Any, List, Optional, Type, Union
 
 class RedisProtocol:
     def publish(self, channel: str, message: bytes): ...
     async def ping(self) -> None: ...
+    async def set(
+        self,
+        key: str,
+        value: Any,
+        expire: Optional[int] = None,
+        pexpire: Optional[int] = None,
+        only_if_not_exists: bool = False,
+        only_if_exists: bool = False,
+    ) -> None: ...
+    async def get(self, key: str) -> Any: ...
 
 class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args, **kwargs): ...
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..8ba669059a 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -18,6 +18,7 @@ from synapse.config import (
     password_auth_providers,
     push,
     ratelimiting,
+    redis,
     registration,
     repository,
     room_directory,
@@ -79,6 +80,7 @@ class RootConfig:
     roomdirectory: room_directory.RoomDirectoryConfig
     thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
     tracer: tracer.TracerConfig
+    redis: redis.RedisConfig
 
     config_classes: List = ...
     def __init__(self) -> None: ...
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
             self._wake_destinations_needing_catchup,
         )
 
+        self._external_cache = hs.get_external_cache()
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -197,22 +199,40 @@ class FederationSender:
                     if not event.internal_metadata.should_proactively_send():
                         return
 
-                    try:
-                        # Get the state from before the event.
-                        # We need to make sure that this is the state from before
-                        # the event and not from after it.
-                        # Otherwise if the last member on a server in a room is
-                        # banned then it won't receive the event because it won't
-                        # be in the room after the ban.
-                        destinations = await self.state.get_hosts_in_room_at_events(
-                            event.room_id, event_ids=event.prev_event_ids()
-                        )
-                    except Exception:
-                        logger.exception(
-                            "Failed to calculate hosts in room for event: %s",
-                            event.event_id,
+                    destinations = None  # type: Optional[Set[str]]
+                    if not event.prev_event_ids():
+                        # If there are no prev event IDs then the state is empty
+                        # and so no remote servers in the room
+                        destinations = set()
+                    else:
+                        # We check the external cache for the destinations, which is
+                        # stored per state group.
+
+                        sg = await self._external_cache.get(
+                            "event_to_prev_state_group", event.event_id
                         )
-                        return
+                        if sg:
+                            destinations = await self._external_cache.get(
+                                "get_joined_hosts", str(sg)
+                            )
+
+                    if destinations is None:
+                        try:
+                            # Get the state from before the event.
+                            # We need to make sure that this is the state from before
+                            # the event and not from after it.
+                            # Otherwise if the last member on a server in a room is
+                            # banned then it won't receive the event because it won't
+                            # be in the room after the ban.
+                            destinations = await self.state.get_hosts_in_room_at_events(
+                                event.room_id, event_ids=event.prev_event_ids()
+                            )
+                        except Exception:
+                            logger.exception(
+                                "Failed to calculate hosts in room for event: %s",
+                                event.event_id,
+                            )
+                            return
 
                     destinations = {
                         d
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..b6dc7f99b6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _check_for_soft_fail(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..e2a7d567fa 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -432,6 +432,8 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
     async def create_event(
         self,
         requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
 
         await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+set_counter = Counter(
+    "synapse_external_cache_set",
+    "Number of times we set a cache",
+    labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+    "synapse_external_cache_get",
+    "Number of times we get a cache",
+    labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+    """A cache backed by an external Redis. Does nothing if no Redis is
+    configured.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._redis_connection = hs.get_outbound_redis_connection()
+
+    def _get_redis_key(self, cache_name: str, key: str) -> str:
+        return "cache_v1:%s:%s" % (cache_name, key)
+
+    def is_enabled(self) -> bool:
+        """Whether the external cache is used or not.
+
+        It's safe to use the cache when this returns false, the methods will
+        just no-op, but the function is useful to avoid doing unnecessary work.
+        """
+        return self._redis_connection is not None
+
+    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+        """Add the key/value to the named cache, with the expiry time given.
+        """
+
+        if self._redis_connection is None:
+            return
+
+        set_counter.labels(cache_name).inc()
+
+        # txredisapi requires the value to be string, bytes or numbers, so we
+        # encode stuff in JSON.
+        encoded_value = json_encoder.encode(value)
+
+        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.set(
+                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+            )
+        )
+
+    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+        """Look up a key/value in the named cache.
+        """
+
+        if self._redis_connection is None:
+            return None
+
+        result = await make_deferred_yieldable(
+            self._redis_connection.get(self._get_redis_key(cache_name, key))
+        )
+
+        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+        get_counter.labels(cache_name, result is not None).inc()
+
+        if not result:
+            return None
+
+        # For some reason the integers get magically converted back to integers
+        if isinstance(result, int):
+            return result
+
+        return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 58d46a5951..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -286,13 +286,6 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
-            )
-
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
             )
 
             # First let's ensure that we have a ReplicationStreamer started.
@@ -303,13 +296,7 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                hs=hs,
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..9bdd3177d7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
 from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from txredisapi import RedisProtocol
+
     from synapse.handlers.oidc_handler import OidcHandler
     from synapse.handlers.saml_handler import SamlHandler
 
@@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_account_data_handler(self) -> AccountDataHandler:
         return AccountDataHandler(self)
 
+    @cache_in_self
+    def get_external_cache(self) -> ExternalCache:
+        return ExternalCache(self)
+
+    @cache_in_self
+    def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+        if not self.config.redis.redis_enabled:
+            return None
+
+        # We only want to import redis module if we're using it, as we have
+        # `txredisapi` as an optional dependency.
+        from synapse.replication.tcp.redis import lazyConnection
+
+        logger.info(
+            "Connecting to redis (host=%r port=%r) for external cache",
+            self.config.redis_host,
+            self.config.redis_port,
+        )
+
+        return lazyConnection(
+            hs=self,
+            host=self.config.redis_host,
+            port=self.config.redis_port,
+            password=self.config.redis.redis_password,
+            reconnect=True,
+        )
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..3bd9ff8ca0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
+            entry = None
 
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
                 current_state_ids=state_ids_before_event,
             )
 
-            # XXX: can we update the state cache entry for the new state group? or
-            # could we set a flag on resolve_state_groups_for_events to tell it to
-            # always make a state group?
+            # Assign the new state group to the cached state entry.
+            #
+            # Note that this can race in that we could generate multiple state
+            # groups for the same state entry, but that is just inefficient
+            # rather than dangerous.
+            if entry and entry.state_group is None:
+                entry.state_group = state_group_before_event
 
         #
         # now if it's not a state event, we're done
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 3379189785..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Fake in memory Redis server that servers can connect to.
         self._redis_server = FakeRedisPubSubServer()
 
+        # We may have an attempt to connect to redis for the external cache already.
+        self.connect_any_redis_attempts()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
@@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         fake one.
         """
         clients = self.reactor.tcpClients
-        self.assertEqual(len(clients), 1)
-        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-        self.assertEqual(host, "localhost")
-        self.assertEqual(port, 6379)
+        while clients:
+            (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+            self.assertEqual(host, "localhost")
+            self.assertEqual(port, 6379)
 
-        client_protocol = client_factory.buildProtocol(None)
-        server_protocol = self._redis_server.buildProtocol(None)
+            client_protocol = client_factory.buildProtocol(None)
+            server_protocol = self._redis_server.buildProtocol(None)
 
-        client_to_server_transport = FakeTransport(
-            server_protocol, self.reactor, client_protocol
-        )
-        client_protocol.makeConnection(client_to_server_transport)
-
-        server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, server_protocol
-        )
-        server_protocol.makeConnection(server_to_client_transport)
+            client_to_server_transport = FakeTransport(
+                server_protocol, self.reactor, client_protocol
+            )
+            client_protocol.makeConnection(client_to_server_transport)
 
-        return client_to_server_transport, server_to_client_transport
+            server_to_client_transport = FakeTransport(
+                client_protocol, self.reactor, server_protocol
+            )
+            server_protocol.makeConnection(server_to_client_transport)
 
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
             (channel,) = args
             self._server.add_subscriber(self)
             self.send(["subscribe", channel, 1])
+
+        # Since we use SET/GET to cache things we can safely no-op them.
+        elif command == b"SET":
+            self.send("OK")
+        elif command == b"GET":
+            self.send(None)
         else:
             raise Exception("Unknown command")
 
@@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
             # We assume bytes are just unicode strings.
             obj = obj.decode("utf-8")
 
+        if obj is None:
+            return "$-1\r\n"
         if isinstance(obj, str):
             return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
         if isinstance(obj, int):