summary refs log tree commit diff
path: root/synapse/state/v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/v2.py')
-rw-r--r--synapse/state/v2.py36
1 files changed, 6 insertions, 30 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 531018c6a5..18484e2fa6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -105,7 +105,7 @@ def resolve_events_with_store(
                 % (room_id, event.event_id, event.room_id,)
             )
 
-    full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
+    full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
 
     logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
 
@@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
     Returns:
         Deferred[set[str]]: Set of event IDs
     """
-    common = set(itervalues(state_sets[0])).intersection(
-        *(itervalues(s) for s in state_sets[1:])
-    )
-
-    auth_sets = []
-    for state_set in state_sets:
-        auth_ids = set(
-            eid
-            for key, eid in iteritems(state_set)
-            if (
-                key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
-                or key
-                in (
-                    (EventTypes.PowerLevels, ""),
-                    (EventTypes.Create, ""),
-                    (EventTypes.JoinRules, ""),
-                )
-            )
-            and eid not in common
-        )
-
-        auth_chain = yield state_res_store.get_auth_chain(auth_ids)
-        auth_ids.update(auth_chain)
 
-        auth_sets.append(auth_ids)
-
-    intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
-    union = set().union(*auth_sets)
+    difference = yield state_res_store.get_auth_chain_difference(
+        [set(state_set.values()) for state_set in state_sets]
+    )
 
-    return union - intersection
+    return difference
 
 
 def _seperate(state_sets):
@@ -275,7 +251,7 @@ def _seperate(state_sets):
     conflicted_state = {}
 
     for key in set(itertools.chain.from_iterable(state_sets)):
-        event_ids = set(state_set.get(key) for state_set in state_sets)
+        event_ids = {state_set.get(key) for state_set in state_sets}
         if len(event_ids) == 1:
             unconflicted_state[key] = event_ids.pop()
         else: