summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py8
-rw-r--r--synapse/state.py22
2 files changed, 12 insertions, 18 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 96963d7434..4873cf9d1f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -28,6 +28,12 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+AuthEventTypes = (
+    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
+    EventTypes.JoinRules,
+)
+
+
 class Auth(object):
 
     def __init__(self, hs):
@@ -427,7 +433,7 @@ class Auth(object):
         context.auth_events = {
             k: v
             for k, v in context.current_state.items()
-            if v.event_id in auth_ids
+            if v.type in AuthEventTypes
         }
 
     def compute_auth_events(self, event, current_state):
diff --git a/synapse/state.py b/synapse/state.py
index 80cced351d..345046cd88 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
 from synapse.util.expiringcache import ExpiringCache
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
+from synapse.api.auth import AuthEventTypes
 from synapse.events.snapshot import EventContext
 
 from collections import namedtuple
@@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
-AuthEventTypes = (
-    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
-    EventTypes.JoinRules,
-)
-
-
 SIZE_OF_CACHE = 1000
 EVICTION_TIMEOUT_SECONDS = 20
 
@@ -187,17 +182,10 @@ class StateHandler(object):
                 replaces = context.current_state[key]
                 event.unsigned["replaces_state"] = replaces.event_id
 
-        if hasattr(event, "auth_events") and event.auth_events:
-            auth_ids = self.hs.get_auth().compute_auth_events(
-                event, context.current_state
-            )
-            context.auth_events = {
-                k: v
-                for k, v in context.current_state.items()
-                if v.event_id in auth_ids
-            }
-        else:
-            context.auth_events = {}
+        context.auth_events = {
+            k: e for k, e in context.current_state.items()
+            if k[0] in AuthEventTypes
+        }
 
         context.prev_state_events = prev_state
         defer.returnValue(context)