From c5b6abd53d1549c7a7cc30c644dd8921bfc10ea2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Dec 2020 15:22:37 +0000 Subject: Correctly handle unpersisted events when calculating auth chain difference. (#8827) We do state res with unpersisted events when calculating the new current state of the room, so that should be the only thing impacted. I don't think this is tooooo big of a deal as: 1. the next time a state event happens in the room the current state should correct itself; 2. in the common case all the unpersisted events' auth events will be pulled in by other state, so will still return the correct result (or one which is sufficiently close to not affect the result); and 3. we mostly use the state at an event to do important operations, which isn't affected by this. --- synapse/state/v2.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 4 deletions(-) (limited to 'synapse/state') diff --git a/synapse/state/v2.py b/synapse/state/v2.py index f57df0d728..ffc504ce77 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import Collection, MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) @@ -252,9 +252,88 @@ async def _get_auth_chain_difference( Set of event IDs """ - difference = await state_res_store.get_auth_chain_difference( - [set(state_set.values()) for state_set in state_sets] - ) + # The `StateResolutionStore.get_auth_chain_difference` function assumes that + # all events passed to it (and their auth chains) have been persisted + # previously. This is not the case for any events in the `event_map`, and so + # we need to manually handle those events. + # + # We do this by: + # 1. calculating the auth chain difference for the state sets based on the + # events in `event_map` alone + # 2. replacing any events in the state_sets that are also in `event_map` + # with their auth events (recursively), and then calling + # `store.get_auth_chain_difference` as normal + # 3. adding the results of 1 and 2 together. + + # Map from event ID in `event_map` to their auth event IDs, and their auth + # event IDs if they appear in the `event_map`. This is the intersection of + # the event's auth chain with the events in the `event_map` *plus* their + # auth event IDs. + events_to_auth_chain = {} # type: Dict[str, Set[str]] + for event in event_map.values(): + chain = {event.event_id} + events_to_auth_chain[event.event_id] = chain + + to_search = [event] + while to_search: + for auth_id in to_search.pop().auth_event_ids(): + chain.add(auth_id) + auth_event = event_map.get(auth_id) + if auth_event: + to_search.append(auth_event) + + # We now a) calculate the auth chain difference for the unpersisted events + # and b) work out the state sets to pass to the store. + # + # Note: If the `event_map` is empty (which is the common case), we can do a + # much simpler calculation. + if event_map: + # The list of state sets to pass to the store, where each state set is a set + # of the event ids making up the state. This is similar to `state_sets`, + # except that (a) we only have event ids, not the complete + # ((type, state_key)->event_id) mappings; and (b) we have stripped out + # unpersisted events and replaced them with the persisted events in + # their auth chain. + state_sets_ids = [] # type: List[Set[str]] + + # For each state set, the unpersisted event IDs reachable (by their auth + # chain) from the events in that set. + unpersisted_set_ids = [] # type: List[Set[str]] + + for state_set in state_sets: + set_ids = set() # type: Set[str] + state_sets_ids.append(set_ids) + + unpersisted_ids = set() # type: Set[str] + unpersisted_set_ids.append(unpersisted_ids) + + for event_id in state_set.values(): + event_chain = events_to_auth_chain.get(event_id) + if event_chain is not None: + # We have an event in `event_map`. We add all the auth + # events that it references (that aren't also in `event_map`). + set_ids.update(e for e in event_chain if e not in event_map) + + # We also add the full chain of unpersisted event IDs + # referenced by this state set, so that we can work out the + # auth chain difference of the unpersisted events. + unpersisted_ids.update(e for e in event_chain if e in event_map) + else: + set_ids.add(event_id) + + # The auth chain difference of the unpersisted events of the state sets + # is calculated by taking the difference between the union and + # intersections. + union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) + intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) + + difference_from_event_map = union - intersection # type: Collection[str] + else: + difference_from_event_map = () + state_sets_ids = [set(state_set.values()) for state_set in state_sets] + + difference = await state_res_store.get_auth_chain_difference(state_sets_ids) + difference.update(difference_from_event_map) return difference -- cgit 1.5.1 From df4b1e9c74d56d79c274149b0dfb0fd5305c7659 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 4 Dec 2020 15:52:49 +0000 Subject: Pass room_id to get_auth_chain_difference (#8879) This is so that we can choose which algorithm to use based on the room ID. --- changelog.d/8879.misc | 1 + synapse/state/__init__.py | 4 ++-- synapse/state/v2.py | 9 +++++++-- synapse/storage/databases/main/event_federation.py | 4 +++- tests/state/test_v2.py | 14 ++++++++++---- tests/storage/test_event_federation.py | 18 ++++++++++-------- 6 files changed, 33 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8879.misc (limited to 'synapse/state') diff --git a/changelog.d/8879.misc b/changelog.d/8879.misc new file mode 100644 index 0000000000..6f9516b314 --- /dev/null +++ b/changelog.d/8879.misc @@ -0,0 +1 @@ +Pass `room_id` to `get_auth_chain_difference`. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1fa3b280b4..84f59c7d85 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -783,7 +783,7 @@ class StateResolutionStore: ) def get_auth_chain_difference( - self, state_sets: List[Set[str]] + self, room_id: str, state_sets: List[Set[str]] ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -796,4 +796,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(state_sets) + return self.store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index ffc504ce77..f85124bf81 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -97,7 +97,9 @@ async def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference( + room_id, state_sets, event_map, state_res_store + ) full_conflicted_set = set( itertools.chain( @@ -236,6 +238,7 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( + room_id: str, state_sets: Sequence[StateMap[str]], event_map: Dict[str, EventBase], state_res_store: "synapse.state.StateResolutionStore", @@ -332,7 +335,9 @@ async def _get_auth_chain_difference( difference_from_event_map = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] - difference = await state_res_store.get_auth_chain_difference(state_sets_ids) + difference = await state_res_store.get_auth_chain_difference( + room_id, state_sets_ids + ) difference.update(difference_from_event_map) return difference diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 2e07c37340..ebffd89251 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: + async def get_auth_chain_difference( + self, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f5c6db900d..09f4f32a02 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -623,7 +623,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {c.event_id}) @@ -662,7 +664,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {d.event_id, c.event_id}) @@ -707,7 +711,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {d.event_id, e.event_id}) @@ -773,7 +779,7 @@ class TestStateResolutionStore: return list(result) - def get_auth_chain_difference(self, auth_sets): + def get_auth_chain_difference(self, room_id, auth_sets): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 71c21d8c75..482506d731 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -202,39 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Now actually test that various combinations give the right result: difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b", "c"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) ) self.assertSetEqual(difference, {"a", "b"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "d", "e"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) ) self.assertSetEqual(difference, {"a", "b"}) - difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) self.assertSetEqual(difference, set()) -- cgit 1.5.1 From 5e7d75daa2d53ea74c75f9a0db52c5590fc22038 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 18 Dec 2020 15:00:34 +0000 Subject: Fix mainline ordering in state res v2 (#8971) This had two effects 1) it'd give the wrong answer and b) would iterate *all* power levels in the auth chain of each event. The latter of which can be *very* expensive for certain types of IRC bridge rooms that have large numbers of power level changes. --- changelog.d/8971.bugfix | 1 + synapse/state/v2.py | 2 +- tests/state/test_v2.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8971.bugfix (limited to 'synapse/state') diff --git a/changelog.d/8971.bugfix b/changelog.d/8971.bugfix new file mode 100644 index 0000000000..c3e44b8c0b --- /dev/null +++ b/changelog.d/8971.bugfix @@ -0,0 +1 @@ +Fix small bug in v2 state resolution algorithm, which could also cause performance issues for rooms with large numbers of power levels. diff --git a/synapse/state/v2.py b/synapse/state/v2.py index f85124bf81..e585954bd8 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -658,7 +658,7 @@ async def _get_mainline_depth_for_event( # We do an iterative search, replacing `event with the power level in its # auth events (if any) while tmp_event: - depth = mainline_map.get(event.event_id) + depth = mainline_map.get(tmp_event.event_id) if depth is not None: return depth diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 09f4f32a02..77c72834f2 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -88,7 +88,7 @@ class FakeEvent: event_dict = { "auth_events": [(a, {}) for a in auth_events], "prev_events": [(p, {}) for p in prev_events], - "event_id": self.node_id, + "event_id": self.event_id, "sender": self.sender, "type": self.type, "content": self.content, @@ -381,6 +381,61 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) + def test_mainline_sort(self): + """Tests that the mainline ordering works correctly. + """ + + events = [ + FakeEvent( + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": {ALICE: 100, BOB: 50}, + "events": {EventTypes.PowerLevels: 100}, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + ] + + edges = [ + ["END", "T3", "PA2", "T2", "PA1", "T1", "START"], + ["END", "T4", "PB", "PA1"], + ] + + # We expect T3 to be picked as the other topics are pointing at older + # power levels. Note that without mainline ordering we'd pick T4 due to + # it being sent *after* T3. + expected_state_ids = ["T3", "PA2"] + + self.do_check(events, edges, expected_state_ids) + def do_check(self, events, edges, expected_state_ids): """Take a list of events and edges and calculate the state of the graph at END, and asserts it matches `expected_state_ids` -- cgit 1.5.1 From dd8da8c5f6ac525a7456437913a03f68d4504605 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Jan 2021 13:57:31 +0000 Subject: Precompute joined hosts and store in Redis (#9198) --- changelog.d/9198.misc | 1 + stubs/txredisapi.pyi | 12 +++- synapse/config/_base.pyi | 2 + synapse/federation/sender/__init__.py | 50 +++++++++----- synapse/handlers/federation.py | 5 ++ synapse/handlers/message.py | 42 ++++++++++++ synapse/replication/tcp/external_cache.py | 105 ++++++++++++++++++++++++++++++ synapse/replication/tcp/handler.py | 15 +---- synapse/server.py | 30 +++++++++ synapse/state/__init__.py | 11 +++- tests/replication/_base.py | 41 +++++++----- 11 files changed, 265 insertions(+), 49 deletions(-) create mode 100644 changelog.d/9198.misc create mode 100644 synapse/replication/tcp/external_cache.py (limited to 'synapse/state') diff --git a/changelog.d/9198.misc b/changelog.d/9198.misc new file mode 100644 index 0000000000..a6cb77fbb2 --- /dev/null +++ b/changelog.d/9198.misc @@ -0,0 +1 @@ +Precompute joined hosts and store in Redis. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bdc892ec82..618548a305 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,11 +15,21 @@ """Contains *incomplete* type hints for txredisapi. """ -from typing import List, Optional, Type, Union +from typing import Any, List, Optional, Type, Union class RedisProtocol: def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... + async def set( + self, + key: str, + value: Any, + expire: Optional[int] = None, + pexpire: Optional[int] = None, + only_if_not_exists: bool = False, + only_if_exists: bool = False, + ) -> None: ... + async def get(self, key: str) -> Any: ... class SubscriberProtocol(RedisProtocol): def __init__(self, *args, **kwargs): ... diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..8ba669059a 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -18,6 +18,7 @@ from synapse.config import ( password_auth_providers, push, ratelimiting, + redis, registration, repository, room_directory, @@ -79,6 +80,7 @@ class RootConfig: roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig tracer: tracer.TracerConfig + redis: redis.RedisConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 604cfd1935..643b26ae6d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -142,6 +142,8 @@ class FederationSender: self._wake_destinations_needing_catchup, ) + self._external_cache = hs.get_external_cache() + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -197,22 +199,40 @@ class FederationSender: if not event.internal_metadata.should_proactively_send(): return - try: - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - destinations = await self.state.get_hosts_in_room_at_events( - event.room_id, event_ids=event.prev_event_ids() - ) - except Exception: - logger.exception( - "Failed to calculate hosts in room for event: %s", - event.event_id, + destinations = None # type: Optional[Set[str]] + if not event.prev_event_ids(): + # If there are no prev event IDs then the state is empty + # and so no remote servers in the room + destinations = set() + else: + # We check the external cache for the destinations, which is + # stored per state group. + + sg = await self._external_cache.get( + "event_to_prev_state_group", event.event_id ) - return + if sg: + destinations = await self._external_cache.get( + "get_joined_hosts", str(sg) + ) + + if destinations is None: + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = await self.state.get_hosts_in_room_at_events( + event.room_id, event_ids=event.prev_event_ids() + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return destinations = { d diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fd8de8696d..b6dc7f99b6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _check_for_soft_fail( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9dfeab09cd..e2a7d567fa 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -432,6 +432,8 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._external_cache = hs.get_external_cache() + async def create_event( self, requester: Requester, @@ -939,6 +941,8 @@ class EventCreationHandler: await self.action_generator.handle_push_actions_for_event(event, context) + await self.cache_joined_hosts_for_event(event) + try: # If we're a worker we need to hit out to the master. writer_instance = self._events_shard_config.get_instance(event.room_id) @@ -978,6 +982,44 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise + async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + """Precalculate the joined hosts at the event, when using Redis, so that + external federation senders don't have to recalculate it themselves. + """ + + if not self._external_cache.is_enabled(): + return + + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We always set the state group -> joined hosts cache, even if + # we already set it, so that the expiry time is reset. + + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() + ) + + if state_entry.state_group: + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) + async def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str ) -> None: diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py new file mode 100644 index 0000000000..34fa3ff5b3 --- /dev/null +++ b/synapse/replication/tcp/external_cache.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from prometheus_client import Counter + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import json_decoder, json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + +set_counter = Counter( + "synapse_external_cache_set", + "Number of times we set a cache", + labelnames=["cache_name"], +) + +get_counter = Counter( + "synapse_external_cache_get", + "Number of times we get a cache", + labelnames=["cache_name", "hit"], +) + + +logger = logging.getLogger(__name__) + + +class ExternalCache: + """A cache backed by an external Redis. Does nothing if no Redis is + configured. + """ + + def __init__(self, hs: "HomeServer"): + self._redis_connection = hs.get_outbound_redis_connection() + + def _get_redis_key(self, cache_name: str, key: str) -> str: + return "cache_v1:%s:%s" % (cache_name, key) + + def is_enabled(self) -> bool: + """Whether the external cache is used or not. + + It's safe to use the cache when this returns false, the methods will + just no-op, but the function is useful to avoid doing unnecessary work. + """ + return self._redis_connection is not None + + async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: + """Add the key/value to the named cache, with the expiry time given. + """ + + if self._redis_connection is None: + return + + set_counter.labels(cache_name).inc() + + # txredisapi requires the value to be string, bytes or numbers, so we + # encode stuff in JSON. + encoded_value = json_encoder.encode(value) + + logger.debug("Caching %s %s: %r", cache_name, key, encoded_value) + + return await make_deferred_yieldable( + self._redis_connection.set( + self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + ) + ) + + async def get(self, cache_name: str, key: str) -> Optional[Any]: + """Look up a key/value in the named cache. + """ + + if self._redis_connection is None: + return None + + result = await make_deferred_yieldable( + self._redis_connection.get(self._get_redis_key(cache_name, key)) + ) + + logger.debug("Got cache result %s %s: %r", cache_name, key, result) + + get_counter.labels(cache_name, result is not None).inc() + + if not result: + return None + + # For some reason the integers get magically converted back to integers + if isinstance(result, int): + return result + + return json_decoder.decode(result) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 58d46a5951..8ea8dcd587 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -286,13 +286,6 @@ class ReplicationCommandHandler: if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, - lazyConnection, - ) - - logger.info( - "Connecting to redis (host=%r port=%r)", - hs.config.redis_host, - hs.config.redis_port, ) # First let's ensure that we have a ReplicationStreamer started. @@ -303,13 +296,7 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = lazyConnection( - hs=hs, - host=hs.config.redis_host, - port=hs.config.redis_port, - password=hs.config.redis.redis_password, - reconnect=True, - ) + outbound_redis_connection = hs.get_outbound_redis_connection() # Now create the factory/connection for the subscription stream. self._factory = RedisDirectTcpReplicationClientFactory( diff --git a/synapse/server.py b/synapse/server.py index 9cdda83aa1..9bdd3177d7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream @@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.handlers.oidc_handler import OidcHandler from synapse.handlers.saml_handler import SamlHandler @@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta): def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) + @cache_in_self + def get_external_cache(self) -> ExternalCache: + return ExternalCache(self) + + @cache_in_self + def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: + if not self.config.redis.redis_enabled: + return None + + # We only want to import redis module if we're using it, as we have + # `txredisapi` as an optional dependency. + from synapse.replication.tcp.redis import lazyConnection + + logger.info( + "Connecting to redis (host=%r port=%r) for external cache", + self.config.redis_host, + self.config.redis_port, + ) + + return lazyConnection( + hs=self, + host=self.config.redis_host, + port=self.config.redis_port, + password=self.config.redis.redis_password, + reconnect=True, + ) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 84f59c7d85..3bd9ff8ca0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -310,6 +310,7 @@ class StateHandler: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None + entry = None else: # otherwise, we'll need to resolve the state across the prev_events. @@ -340,9 +341,13 @@ class StateHandler: current_state_ids=state_ids_before_event, ) - # XXX: can we update the state cache entry for the new state group? or - # could we set a flag on resolve_state_groups_for_events to tell it to - # always make a state group? + # Assign the new state group to the cached state entry. + # + # Note that this can race in that we could generate multiple state + # groups for the same state entry, but that is just inefficient + # rather than dangerous. + if entry and entry.state_group is None: + entry.state_group = state_group_before_event # # now if it's not a state event, we're done diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 3379189785..d5dce1f83f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Fake in memory Redis server that servers can connect to. self._redis_server = FakeRedisPubSubServer() + # We may have an attempt to connect to redis for the external cache already. + self.connect_any_redis_attempts() + store = self.hs.get_datastore() self.database_pool = store.db_pool @@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): fake one. """ clients = self.reactor.tcpClients - self.assertEqual(len(clients), 1) - (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") - self.assertEqual(port, 6379) + while clients: + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) - client_to_server_transport = FakeTransport( - server_protocol, self.reactor, client_protocol - ) - client_protocol.makeConnection(client_to_server_transport) - - server_to_client_transport = FakeTransport( - client_protocol, self.reactor, server_protocol - ) - server_protocol.makeConnection(server_to_client_transport) + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) - return client_to_server_transport, server_to_client_transport + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) class TestReplicationDataHandler(GenericWorkerReplicationHandler): @@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol): (channel,) = args self._server.add_subscriber(self) self.send(["subscribe", channel, 1]) + + # Since we use SET/GET to cache things we can safely no-op them. + elif command == b"SET": + self.send("OK") + elif command == b"GET": + self.send(None) else: raise Exception("Unknown command") @@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol): # We assume bytes are just unicode strings. obj = obj.decode("utf-8") + if obj is None: + return "$-1\r\n" if isinstance(obj, str): return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) if isinstance(obj, int): -- cgit 1.5.1