summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py34
-rw-r--r--synapse/events/snapshot.py3
-rw-r--r--synapse/handlers/_base.py6
-rw-r--r--synapse/handlers/federation.py8
-rw-r--r--synapse/state.py31
-rw-r--r--synapse/storage/state.py2
6 files changed, 32 insertions, 52 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b176db8ce1..90f9eb6847 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -28,6 +28,12 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+AuthEventTypes = (
+    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
+    EventTypes.JoinRules,
+)
+
+
 class Auth(object):
 
     def __init__(self, hs):
@@ -166,6 +172,7 @@ class Auth(object):
         target = auth_events.get(key)
 
         target_in_room = target and target.membership == Membership.JOIN
+        target_banned = target and target.membership == Membership.BAN
 
         key = (EventTypes.JoinRules, "", )
         join_rule_event = auth_events.get(key)
@@ -194,6 +201,7 @@ class Auth(object):
             {
                 "caller_in_room": caller_in_room,
                 "caller_invited": caller_invited,
+                "target_banned": target_banned,
                 "target_in_room": target_in_room,
                 "membership": membership,
                 "join_rule": join_rule,
@@ -202,6 +210,11 @@ class Auth(object):
             }
         )
 
+        if ban_level:
+            ban_level = int(ban_level)
+        else:
+            ban_level = 50  # FIXME (erikj): What should we do here?
+
         if Membership.INVITE == membership:
             # TODO (erikj): We should probably handle this more intelligently
             # PRIVATE join rules.
@@ -212,6 +225,10 @@ class Auth(object):
                     403,
                     "%s not in room %s." % (event.user_id, event.room_id,)
                 )
+            elif target_banned:
+                raise AuthError(
+                    403, "%s is banned from the room" % (target_user_id,)
+                )
             elif target_in_room:  # the target is already in the room.
                 raise AuthError(403, "%s is already in the room." %
                                      target_user_id)
@@ -221,6 +238,8 @@ class Auth(object):
             # joined: It's a NOOP
             if event.user_id != target_user_id:
                 raise AuthError(403, "Cannot force another user to join.")
+            elif target_banned:
+                raise AuthError(403, "You are banned from this room")
             elif join_rule == JoinRules.PUBLIC:
                 pass
             elif join_rule == JoinRules.INVITE:
@@ -238,6 +257,10 @@ class Auth(object):
                     403,
                     "%s not in room %s." % (target_user_id, event.room_id,)
                 )
+            elif target_banned and user_level < ban_level:
+                raise AuthError(
+                    403, "You cannot unban user &s." % (target_user_id,)
+                )
             elif target_user_id != event.user_id:
                 if kick_level:
                     kick_level = int(kick_level)
@@ -249,11 +272,6 @@ class Auth(object):
                         403, "You cannot kick user %s." % target_user_id
                     )
         elif Membership.BAN == membership:
-            if ban_level:
-                ban_level = int(ban_level)
-            else:
-                ban_level = 50  # FIXME (erikj): What should we do here?
-
             if user_level < ban_level:
                 raise AuthError(403, "You don't have permission to ban")
         else:
@@ -412,12 +430,6 @@ class Auth(object):
 
         builder.auth_events = auth_events_entries
 
-        context.auth_events = {
-            k: v
-            for k, v in context.current_state.items()
-            if v.event_id in auth_ids
-        }
-
     def compute_auth_events(self, event, current_state):
         if event.type == EventTypes.Create:
             return []
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7e98bdef28..4ecadf0879 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -16,8 +16,7 @@
 
 class EventContext(object):
 
-    def __init__(self, current_state=None, auth_events=None):
+    def __init__(self, current_state=None):
         self.current_state = current_state
-        self.auth_events = auth_events
         self.state_group = None
         self.rejected = False
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 7f07f05215..48816a242d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -90,8 +90,8 @@ class BaseHandler(object):
         event = builder.build()
 
         logger.debug(
-            "Created event %s with auth_events: %s, current state: %s",
-            event.event_id, context.auth_events, context.current_state,
+            "Created event %s with current state: %s",
+            event.event_id, context.current_state,
         )
 
         defer.returnValue(
@@ -106,7 +106,7 @@ class BaseHandler(object):
         # We now need to go and hit out to wherever we need to hit out to.
 
         if not suppress_auth:
-            self.auth.check(event, auth_events=context.auth_events)
+            self.auth.check(event, auth_events=context.current_state)
 
         yield self.store.persist_event(event, context=context)
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ae4e9b316d..65cfacba2e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -464,11 +464,9 @@ class FederationHandler(BaseHandler):
             builder=builder,
         )
 
-        self.auth.check(event, auth_events=context.auth_events)
+        self.auth.check(event, auth_events=context.current_state)
 
-        pdu = event
-
-        defer.returnValue(pdu)
+        defer.returnValue(event)
 
     @defer.inlineCallbacks
     @log_function
@@ -705,7 +703,7 @@ class FederationHandler(BaseHandler):
         )
 
         if not auth_events:
-            auth_events = context.auth_events
+            auth_events = context.current_state
 
         logger.debug(
             "_handle_new_event: %s, auth_events: %s",
diff --git a/synapse/state.py b/synapse/state.py
index 80cced351d..ba2500d61c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
 from synapse.util.expiringcache import ExpiringCache
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
+from synapse.api.auth import AuthEventTypes
 from synapse.events.snapshot import EventContext
 
 from collections import namedtuple
@@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
-AuthEventTypes = (
-    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
-    EventTypes.JoinRules,
-)
-
-
 SIZE_OF_CACHE = 1000
 EVICTION_TIMEOUT_SECONDS = 20
 
@@ -139,18 +134,6 @@ class StateHandler(object):
             }
             context.state_group = None
 
-            if hasattr(event, "auth_events") and event.auth_events:
-                auth_ids = self.hs.get_auth().compute_auth_events(
-                    event, context.current_state
-                )
-                context.auth_events = {
-                    k: v
-                    for k, v in context.current_state.items()
-                    if v.event_id in auth_ids
-                }
-            else:
-                context.auth_events = {}
-
             if event.is_state():
                 key = (event.type, event.state_key)
                 if key in context.current_state:
@@ -187,18 +170,6 @@ class StateHandler(object):
                 replaces = context.current_state[key]
                 event.unsigned["replaces_state"] = replaces.event_id
 
-        if hasattr(event, "auth_events") and event.auth_events:
-            auth_ids = self.hs.get_auth().compute_auth_events(
-                event, context.current_state
-            )
-            context.auth_events = {
-                k: v
-                for k, v in context.current_state.items()
-                if v.event_id in auth_ids
-            }
-        else:
-            context.auth_events = {}
-
         context.prev_state_events = prev_state
         defer.returnValue(context)
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 71db16d0e5..456e4bd45d 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -82,7 +82,7 @@ class StateStore(SQLBaseStore):
         if context.current_state is None:
             return
 
-        state_events = context.current_state
+        state_events = dict(context.current_state)
 
         if event.is_state():
             state_events[(event.type, event.state_key)] = event