diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdd6bef6b4..4afefc6b1d 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,7 +16,7 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Set
from six import iteritems, itervalues
@@ -662,23 +662,16 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
- def get_auth_chain(self, event_ids):
- """Gets the full auth chain for a set of events (including rejected
- events).
+ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ """Given sets of state events figure out the auth chain difference (as
+ per state res v2 algorithm).
- Includes the given event IDs in the result.
-
- Note that:
- 1. All events must be state events.
- 2. For v1 rooms this may not have the full auth chain in the
- presence of rejected events
-
- Args:
- event_ids (list): The event IDs of the events to fetch the auth
- chain for. Must be state events.
+ This equivalent to fetching the full auth chain for each set of state
+ and returning the events that don't appear in each and every auth
+ chain.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ Deferred[Set[str]]: Set of event IDs.
"""
- return self.store.get_auth_chain_ids(event_ids, include_given=True)
+ return self.store.get_auth_chain_difference(state_sets)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 24b7c0faef..9bf98d06f2 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -69,9 +69,9 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state = _seperate(state_sets)
- needed_events = set(
+ needed_events = {
event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
- )
+ }
needed_event_count = len(needed_events)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
@@ -261,11 +261,11 @@ def _resolve_state_events(conflicted_state, auth_events):
def _resolve_auth_events(events, auth_events):
- reverse = [i for i in reversed(_ordered_events(events))]
+ reverse = list(reversed(_ordered_events(events)))
- auth_keys = set(
+ auth_keys = {
key for event in events for key in event_auth.auth_types_for_event(event)
- )
+ }
new_auth_events = {}
for key in auth_keys:
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 531018c6a5..18484e2fa6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -105,7 +105,7 @@ def resolve_events_with_store(
% (room_id, event.event_id, event.room_id,)
)
- full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
+ full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
@@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns:
Deferred[set[str]]: Set of event IDs
"""
- common = set(itervalues(state_sets[0])).intersection(
- *(itervalues(s) for s in state_sets[1:])
- )
-
- auth_sets = []
- for state_set in state_sets:
- auth_ids = set(
- eid
- for key, eid in iteritems(state_set)
- if (
- key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
- or key
- in (
- (EventTypes.PowerLevels, ""),
- (EventTypes.Create, ""),
- (EventTypes.JoinRules, ""),
- )
- )
- and eid not in common
- )
-
- auth_chain = yield state_res_store.get_auth_chain(auth_ids)
- auth_ids.update(auth_chain)
- auth_sets.append(auth_ids)
-
- intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
- union = set().union(*auth_sets)
+ difference = yield state_res_store.get_auth_chain_difference(
+ [set(state_set.values()) for state_set in state_sets]
+ )
- return union - intersection
+ return difference
def _seperate(state_sets):
@@ -275,7 +251,7 @@ def _seperate(state_sets):
conflicted_state = {}
for key in set(itertools.chain.from_iterable(state_sets)):
- event_ids = set(state_set.get(key) for state_set in state_sets)
+ event_ids = {state_set.get(key) for state_set in state_sets}
if len(event_ids) == 1:
unconflicted_state[key] = event_ids.pop()
else:
|