diff --git a/synapse/state.py b/synapse/state.py
index 8a056ee955..695a5e7ac4 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -37,7 +37,10 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
+AuthEventTypes = (
+ EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
+ EventTypes.JoinRules,
+)
class StateHandler(object):
@@ -100,7 +103,9 @@ class StateHandler(object):
context.state_group = None
if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = zip(*event.auth_events)[0]
+ auth_ids = self.hs.get_auth().compute_auth_events(
+ event, context.current_state
+ )
context.auth_events = {
k: v
for k, v in context.current_state.items()
@@ -146,7 +151,9 @@ class StateHandler(object):
event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = zip(*event.auth_events)[0]
+ auth_ids = self.hs.get_auth().compute_auth_events(
+ event, context.current_state
+ )
context.auth_events = {
k: v
for k, v in context.current_state.items()
@@ -259,6 +266,15 @@ class StateHandler(object):
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
+ if key[0] == EventTypes.JoinRules:
+ resolved_state[key] = self._resolve_auth_events(
+ events,
+ auth_events
+ )
+
+ auth_events.update(resolved_state)
+
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
resolved_state[key] = self._resolve_auth_events(
events,
|