From c5b6abd53d1549c7a7cc30c644dd8921bfc10ea2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Dec 2020 15:22:37 +0000 Subject: Correctly handle unpersisted events when calculating auth chain difference. (#8827) We do state res with unpersisted events when calculating the new current state of the room, so that should be the only thing impacted. I don't think this is tooooo big of a deal as: 1. the next time a state event happens in the room the current state should correct itself; 2. in the common case all the unpersisted events' auth events will be pulled in by other state, so will still return the correct result (or one which is sufficiently close to not affect the result); and 3. we mostly use the state at an event to do important operations, which isn't affected by this. --- synapse/state/v2.py | 87 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 83 insertions(+), 4 deletions(-) (limited to 'synapse/state/v2.py') diff --git a/synapse/state/v2.py b/synapse/state/v2.py index f57df0d728..ffc504ce77 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase -from synapse.types import MutableStateMap, StateMap +from synapse.types import Collection, MutableStateMap, StateMap from synapse.util import Clock logger = logging.getLogger(__name__) @@ -252,9 +252,88 @@ async def _get_auth_chain_difference( Set of event IDs """ - difference = await state_res_store.get_auth_chain_difference( - [set(state_set.values()) for state_set in state_sets] - ) + # The `StateResolutionStore.get_auth_chain_difference` function assumes that + # all events passed to it (and their auth chains) have been persisted + # previously. This is not the case for any events in the `event_map`, and so + # we need to manually handle those events. + # + # We do this by: + # 1. calculating the auth chain difference for the state sets based on the + # events in `event_map` alone + # 2. replacing any events in the state_sets that are also in `event_map` + # with their auth events (recursively), and then calling + # `store.get_auth_chain_difference` as normal + # 3. adding the results of 1 and 2 together. + + # Map from event ID in `event_map` to their auth event IDs, and their auth + # event IDs if they appear in the `event_map`. This is the intersection of + # the event's auth chain with the events in the `event_map` *plus* their + # auth event IDs. + events_to_auth_chain = {} # type: Dict[str, Set[str]] + for event in event_map.values(): + chain = {event.event_id} + events_to_auth_chain[event.event_id] = chain + + to_search = [event] + while to_search: + for auth_id in to_search.pop().auth_event_ids(): + chain.add(auth_id) + auth_event = event_map.get(auth_id) + if auth_event: + to_search.append(auth_event) + + # We now a) calculate the auth chain difference for the unpersisted events + # and b) work out the state sets to pass to the store. + # + # Note: If the `event_map` is empty (which is the common case), we can do a + # much simpler calculation. + if event_map: + # The list of state sets to pass to the store, where each state set is a set + # of the event ids making up the state. This is similar to `state_sets`, + # except that (a) we only have event ids, not the complete + # ((type, state_key)->event_id) mappings; and (b) we have stripped out + # unpersisted events and replaced them with the persisted events in + # their auth chain. + state_sets_ids = [] # type: List[Set[str]] + + # For each state set, the unpersisted event IDs reachable (by their auth + # chain) from the events in that set. + unpersisted_set_ids = [] # type: List[Set[str]] + + for state_set in state_sets: + set_ids = set() # type: Set[str] + state_sets_ids.append(set_ids) + + unpersisted_ids = set() # type: Set[str] + unpersisted_set_ids.append(unpersisted_ids) + + for event_id in state_set.values(): + event_chain = events_to_auth_chain.get(event_id) + if event_chain is not None: + # We have an event in `event_map`. We add all the auth + # events that it references (that aren't also in `event_map`). + set_ids.update(e for e in event_chain if e not in event_map) + + # We also add the full chain of unpersisted event IDs + # referenced by this state set, so that we can work out the + # auth chain difference of the unpersisted events. + unpersisted_ids.update(e for e in event_chain if e in event_map) + else: + set_ids.add(event_id) + + # The auth chain difference of the unpersisted events of the state sets + # is calculated by taking the difference between the union and + # intersections. + union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) + intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) + + difference_from_event_map = union - intersection # type: Collection[str] + else: + difference_from_event_map = () + state_sets_ids = [set(state_set.values()) for state_set in state_sets] + + difference = await state_res_store.get_auth_chain_difference(state_sets_ids) + difference.update(difference_from_event_map) return difference -- cgit 1.5.1 From df4b1e9c74d56d79c274149b0dfb0fd5305c7659 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 4 Dec 2020 15:52:49 +0000 Subject: Pass room_id to get_auth_chain_difference (#8879) This is so that we can choose which algorithm to use based on the room ID. --- changelog.d/8879.misc | 1 + synapse/state/__init__.py | 4 ++-- synapse/state/v2.py | 9 +++++++-- synapse/storage/databases/main/event_federation.py | 4 +++- tests/state/test_v2.py | 14 ++++++++++---- tests/storage/test_event_federation.py | 18 ++++++++++-------- 6 files changed, 33 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8879.misc (limited to 'synapse/state/v2.py') diff --git a/changelog.d/8879.misc b/changelog.d/8879.misc new file mode 100644 index 0000000000..6f9516b314 --- /dev/null +++ b/changelog.d/8879.misc @@ -0,0 +1 @@ +Pass `room_id` to `get_auth_chain_difference`. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 1fa3b280b4..84f59c7d85 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -783,7 +783,7 @@ class StateResolutionStore: ) def get_auth_chain_difference( - self, state_sets: List[Set[str]] + self, room_id: str, state_sets: List[Set[str]] ) -> Awaitable[Set[str]]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). @@ -796,4 +796,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(state_sets) + return self.store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index ffc504ce77..f85124bf81 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -97,7 +97,9 @@ async def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference( + room_id, state_sets, event_map, state_res_store + ) full_conflicted_set = set( itertools.chain( @@ -236,6 +238,7 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( + room_id: str, state_sets: Sequence[StateMap[str]], event_map: Dict[str, EventBase], state_res_store: "synapse.state.StateResolutionStore", @@ -332,7 +335,9 @@ async def _get_auth_chain_difference( difference_from_event_map = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] - difference = await state_res_store.get_auth_chain_difference(state_sets_ids) + difference = await state_res_store.get_auth_chain_difference( + room_id, state_sets_ids + ) difference.update(difference_from_event_map) return difference diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 2e07c37340..ebffd89251 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas return list(results) - async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]: + async def get_auth_chain_difference( + self, room_id: str, state_sets: List[Set[str]] + ) -> Set[str]: """Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index f5c6db900d..09f4f32a02 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -623,7 +623,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {c.event_id}) @@ -662,7 +664,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {d.event_id, c.event_id}) @@ -707,7 +711,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) - diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store) + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) self.assertEqual(difference, {d.event_id, e.event_id}) @@ -773,7 +779,7 @@ class TestStateResolutionStore: return list(result) - def get_auth_chain_difference(self, auth_sets): + def get_auth_chain_difference(self, room_id, auth_sets): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 71c21d8c75..482506d731 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -202,39 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # Now actually test that various combinations give the right result: difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a", "c"}, {"b"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}]) ) self.assertSetEqual(difference, {"a", "b", "c"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}]) + self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}]) ) self.assertSetEqual(difference, {"a", "b"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "d", "e"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}]) ) self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"}) difference = self.get_success( - self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}]) + self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}]) ) self.assertSetEqual(difference, {"a", "b"}) - difference = self.get_success(self.store.get_auth_chain_difference([{"a"}])) + difference = self.get_success( + self.store.get_auth_chain_difference(room_id, [{"a"}]) + ) self.assertSetEqual(difference, set()) -- cgit 1.5.1