diff options
author | Erik Johnston <erik@matrix.org> | 2020-03-18 16:46:41 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-18 16:46:41 +0000 |
commit | 4a17a647a9508b70de35130fd82e3e21474270a9 (patch) | |
tree | 5dae0bdea89f8639d6990854913fd81bfd9755ab /tests/state/test_v2.py | |
parent | Add an option to the set password API to choose whether to logout other devic... (diff) | |
download | synapse-4a17a647a9508b70de35130fd82e3e21474270a9.tar.xz |
Improve get auth chain difference algorithm. (#7095)
It was originally implemented by pulling the full auth chain of all state sets out of the database and doing set comparison. However, that can take a lot work if the state and auth chains are large. Instead, lets try and fetch the auth chains at the same time and calculate the difference on the fly, allowing us to bail early if all the auth chains converge. Assuming that the auth chains do converge more often than not, this should improve performance. Hopefully.
Diffstat (limited to 'tests/state/test_v2.py')
-rw-r--r-- | tests/state/test_v2.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 5059ade850..a44960203e 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -603,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids, ignore_events): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -617,9 +617,6 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - ignore_events: Set of events to exclude from the returned auth - chain. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ @@ -629,7 +626,7 @@ class TestStateResolutionStore(object): stack = list(event_ids) while stack: event_id = stack.pop() - if event_id in result or event_id in ignore_events: + if event_id in result: continue result.add(event_id) @@ -639,3 +636,9 @@ class TestStateResolutionStore(object): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common |