diff --git a/changelog.d/6952.misc b/changelog.d/6952.misc
new file mode 100644
index 0000000000..e26dc5cab8
--- /dev/null
+++ b/changelog.d/6952.misc
@@ -0,0 +1 @@
+Improve perf of v2 state res for large rooms.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdd6bef6b4..df7a4f6a89 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Set
from six import iteritems, itervalues
@@ -662,7 +662,7 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
- def get_auth_chain(self, event_ids):
+ def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -674,11 +674,16 @@ class StateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
- chain for. Must be state events.
+ event_ids: The event IDs of the events to fetch the auth chain for.
+ Must be state events.
+ ignore_events: Set of events to exclude from the returned auth
+ chain.
+
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
- return self.store.get_auth_chain_ids(event_ids, include_given=True)
+ return self.store.get_auth_chain_ids(
+ event_ids, include_given=True, ignore_events=ignore_events,
+ )
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 531018c6a5..75fe58305a 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -248,7 +248,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
and eid not in common
)
- auth_chain = yield state_res_store.get_auth_chain(auth_ids)
+ auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
auth_ids.update(auth_chain)
auth_sets.append(auth_ids)
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 60c67457b4..e16da2577d 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -14,6 +14,7 @@
# limitations under the License.
import itertools
import logging
+from typing import List, Optional, Set
from six.moves import range
from six.moves.queue import Empty, PriorityQueue
@@ -46,21 +47,37 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
- def get_auth_chain_ids(self, event_ids, include_given=False):
+ def get_auth_chain_ids(
+ self,
+ event_ids: List[str],
+ include_given: bool = False,
+ ignore_events: Optional[Set[str]] = None,
+ ):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
+ ignore_events: Set of events to exclude from the returned auth
+ chain. This is useful if the caller will just discard the
+ given events anyway, and saves us from figuring out their auth
+ chains if not required.
Returns:
list of event_ids
"""
return self.db.runInteraction(
- "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+ "get_auth_chain_ids",
+ self._get_auth_chain_ids_txn,
+ event_ids,
+ include_given,
+ ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+ if ignore_events is None:
+ ignore_events = set()
+
if include_given:
results = set(event_ids)
else:
@@ -80,6 +97,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(base_sql + clause, list(args))
new_front.update([r[0] for r in txn])
+ new_front -= ignore_events
new_front -= results
front = new_front
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 5bafad9f19..5059ade850 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -603,7 +603,7 @@ class TestStateResolutionStore(object):
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
- def get_auth_chain(self, event_ids):
+ def get_auth_chain(self, event_ids, ignore_events):
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -617,6 +617,8 @@ class TestStateResolutionStore(object):
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
+ ignore_events: Set of events to exclude from the returned auth
+ chain.
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
@@ -627,7 +629,7 @@ class TestStateResolutionStore(object):
stack = list(event_ids)
while stack:
event_id = stack.pop()
- if event_id in result:
+ if event_id in result or event_id in ignore_events:
continue
result.add(event_id)
|