diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..8ba669059a 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -18,6 +18,7 @@ from synapse.config import (
password_auth_providers,
push,
ratelimiting,
+ redis,
registration,
repository,
room_directory,
@@ -79,6 +80,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
+ redis: redis.RedisConfig
config_classes: List = ...
def __init__(self) -> None: ...
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
self._wake_destinations_needing_catchup,
)
+ self._external_cache = hs.get_external_cache()
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -197,22 +199,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send():
return
- try:
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = await self.state.get_hosts_in_room_at_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
- except Exception:
- logger.exception(
- "Failed to calculate hosts in room for event: %s",
- event.event_id,
+ destinations = None # type: Optional[Set[str]]
+ if not event.prev_event_ids():
+ # If there are no prev event IDs then the state is empty
+ # and so no remote servers in the room
+ destinations = set()
+ else:
+ # We check the external cache for the destinations, which is
+ # stored per state group.
+
+ sg = await self._external_cache.get(
+ "event_to_prev_state_group", event.event_id
)
- return
+ if sg:
+ destinations = await self._external_cache.get(
+ "get_joined_hosts", str(sg)
+ )
+
+ if destinations is None:
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = await self.state.get_hosts_in_room_at_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
destinations = {
d
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..b6dc7f99b6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
+ # If we are going to send this event over federation we precaclculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
return context
async def _check_for_soft_fail(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9dfeab09cd..e2a7d567fa 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -432,6 +432,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._external_cache = hs.get_external_cache()
+
async def create_event(
self,
requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
+ await self.cache_joined_hosts_for_event(event)
+
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
+ async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+ """Precalculate the joined hosts at the event, when using Redis, so that
+ external federation senders don't have to recalculate it themselves.
+ """
+
+ if not self._external_cache.is_enabled():
+ return
+
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We always set the state group -> joined hosts cache, even if
+ # we already set it, so that the expiry time is reset.
+
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+
+ if state_entry.state_group:
+ joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
+
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+set_counter = Counter(
+ "synapse_external_cache_set",
+ "Number of times we set a cache",
+ labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+ "synapse_external_cache_get",
+ "Number of times we get a cache",
+ labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+ """A cache backed by an external Redis. Does nothing if no Redis is
+ configured.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._redis_connection = hs.get_outbound_redis_connection()
+
+ def _get_redis_key(self, cache_name: str, key: str) -> str:
+ return "cache_v1:%s:%s" % (cache_name, key)
+
+ def is_enabled(self) -> bool:
+ """Whether the external cache is used or not.
+
+ It's safe to use the cache when this returns false, the methods will
+ just no-op, but the function is useful to avoid doing unnecessary work.
+ """
+ return self._redis_connection is not None
+
+ async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+ """Add the key/value to the named cache, with the expiry time given.
+ """
+
+ if self._redis_connection is None:
+ return
+
+ set_counter.labels(cache_name).inc()
+
+ # txredisapi requires the value to be string, bytes or numbers, so we
+ # encode stuff in JSON.
+ encoded_value = json_encoder.encode(value)
+
+ logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+ return await make_deferred_yieldable(
+ self._redis_connection.set(
+ self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+ )
+ )
+
+ async def get(self, cache_name: str, key: str) -> Optional[Any]:
+ """Look up a key/value in the named cache.
+ """
+
+ if self._redis_connection is None:
+ return None
+
+ result = await make_deferred_yieldable(
+ self._redis_connection.get(self._get_redis_key(cache_name, key))
+ )
+
+ logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+ get_counter.labels(cache_name, result is not None).inc()
+
+ if not result:
+ return None
+
+ # For some reason the integers get magically converted back to integers
+ if isinstance(result, int):
+ return result
+
+ return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 58d46a5951..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -286,13 +286,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
- lazyConnection,
- )
-
- logger.info(
- "Connecting to redis (host=%r port=%r)",
- hs.config.redis_host,
- hs.config.redis_port,
)
# First let's ensure that we have a ReplicationStreamer started.
@@ -303,13 +296,7 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = lazyConnection(
- hs=hs,
- host=hs.config.redis_host,
- port=hs.config.redis_port,
- password=hs.config.redis.redis_password,
- reconnect=True,
- )
+ outbound_redis_connection = hs.get_outbound_redis_connection()
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..9bdd3177d7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from txredisapi import RedisProtocol
+
from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler
@@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
+ @cache_in_self
+ def get_external_cache(self) -> ExternalCache:
+ return ExternalCache(self)
+
+ @cache_in_self
+ def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+ if not self.config.redis.redis_enabled:
+ return None
+
+ # We only want to import redis module if we're using it, as we have
+ # `txredisapi` as an optional dependency.
+ from synapse.replication.tcp.redis import lazyConnection
+
+ logger.info(
+ "Connecting to redis (host=%r port=%r) for external cache",
+ self.config.redis_host,
+ self.config.redis_port,
+ )
+
+ return lazyConnection(
+ hs=self,
+ host=self.config.redis_host,
+ port=self.config.redis_port,
+ password=self.config.redis.redis_password,
+ reconnect=True,
+ )
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..3bd9ff8ca0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
+ entry = None
else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
- # XXX: can we update the state cache entry for the new state group? or
- # could we set a flag on resolve_state_groups_for_events to tell it to
- # always make a state group?
+ # Assign the new state group to the cached state entry.
+ #
+ # Note that this can race in that we could generate multiple state
+ # groups for the same state entry, but that is just inefficient
+ # rather than dangerous.
+ if entry and entry.state_group is None:
+ entry.state_group = state_group_before_event
#
# now if it's not a state event, we're done
|