summary refs log tree commit diff
path: root/synapse/api/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api/auth.py')
-rw-r--r--synapse/api/auth.py175
1 files changed, 80 insertions, 95 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 64f605b962..d5bf0be85c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -18,9 +18,8 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership, JoinRules
-from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor
 from synapse.types import UserID, ClientInfo
 
 import logging
@@ -40,6 +39,7 @@ class Auth(object):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
+        self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
 
     def check(self, event, auth_events):
         """ Checks if this event is correctly authed.
@@ -64,7 +64,10 @@ class Auth(object):
             if event.type == EventTypes.Aliases:
                 return True
 
-            logger.debug("Auth events: %s", auth_events)
+            logger.debug(
+                "Auth events: %s",
+                [a.event_id for a in auth_events.values()]
+            )
 
             if event.type == EventTypes.Member:
                 allowed = self.is_membership_change_allowed(
@@ -183,18 +186,10 @@ class Auth(object):
         else:
             join_rule = JoinRules.INVITE
 
-        user_level = self._get_power_level_from_event_state(
-            event,
-            event.user_id,
-            auth_events,
-        )
+        user_level = self._get_user_power_level(event.user_id, auth_events)
 
-        ban_level, kick_level, redact_level = (
-            self._get_ops_level_from_event_state(
-                event,
-                auth_events,
-            )
-        )
+        # FIXME (erikj): What should we do here as the default?
+        ban_level = self._get_named_level(auth_events, "ban", 50)
 
         logger.debug(
             "is_membership_change_allowed: %s",
@@ -210,28 +205,33 @@ class Auth(object):
             }
         )
 
-        if ban_level:
-            ban_level = int(ban_level)
-        else:
-            ban_level = 50  # FIXME (erikj): What should we do here?
+        if Membership.JOIN != membership:
+            # JOIN is the only action you can perform if you're not in the room
+            if not caller_in_room:  # caller isn't joined
+                raise AuthError(
+                    403,
+                    "%s not in room %s." % (event.user_id, event.room_id,)
+                )
 
         if Membership.INVITE == membership:
             # TODO (erikj): We should probably handle this more intelligently
             # PRIVATE join rules.
 
             # Invites are valid iff caller is in the room and target isn't.
-            if not caller_in_room:  # caller isn't joined
-                raise AuthError(
-                    403,
-                    "%s not in room %s." % (event.user_id, event.room_id,)
-                )
-            elif target_banned:
+            if target_banned:
                 raise AuthError(
                     403, "%s is banned from the room" % (target_user_id,)
                 )
             elif target_in_room:  # the target is already in the room.
                 raise AuthError(403, "%s is already in the room." %
                                      target_user_id)
+            else:
+                invite_level = self._get_named_level(auth_events, "invite", 0)
+
+                if user_level < invite_level:
+                    raise AuthError(
+                        403, "You cannot invite user %s." % target_user_id
+                    )
         elif Membership.JOIN == membership:
             # Joins are valid iff caller == target and they were:
             # invited: They are accepting the invitation
@@ -251,21 +251,12 @@ class Auth(object):
                 raise AuthError(403, "You are not allowed to join this room")
         elif Membership.LEAVE == membership:
             # TODO (erikj): Implement kicks.
-
-            if not caller_in_room:  # trying to leave a room you aren't joined
-                raise AuthError(
-                    403,
-                    "%s not in room %s." % (target_user_id, event.room_id,)
-                )
-            elif target_banned and user_level < ban_level:
+            if target_banned and user_level < ban_level:
                 raise AuthError(
                     403, "You cannot unban user &s." % (target_user_id,)
                 )
             elif target_user_id != event.user_id:
-                if kick_level:
-                    kick_level = int(kick_level)
-                else:
-                    kick_level = 50  # FIXME (erikj): What should we do here?
+                kick_level = self._get_named_level(auth_events, "kick", 50)
 
                 if user_level < kick_level:
                     raise AuthError(
@@ -279,34 +270,42 @@ class Auth(object):
 
         return True
 
-    def _get_power_level_from_event_state(self, event, user_id, auth_events):
+    def _get_power_level_event(self, auth_events):
         key = (EventTypes.PowerLevels, "", )
-        power_level_event = auth_events.get(key)
-        level = None
+        return auth_events.get(key)
+
+    def _get_user_power_level(self, user_id, auth_events):
+        power_level_event = self._get_power_level_event(auth_events)
+
         if power_level_event:
             level = power_level_event.content.get("users", {}).get(user_id)
             if not level:
                 level = power_level_event.content.get("users_default", 0)
+
+            if level is None:
+                return 0
+            else:
+                return int(level)
         else:
             key = (EventTypes.Create, "", )
             create_event = auth_events.get(key)
             if (create_event is not None and
                     create_event.content["creator"] == user_id):
                 return 100
+            else:
+                return 0
 
-        return level
+    def _get_named_level(self, auth_events, name, default):
+        power_level_event = self._get_power_level_event(auth_events)
 
-    def _get_ops_level_from_event_state(self, event, auth_events):
-        key = (EventTypes.PowerLevels, "", )
-        power_level_event = auth_events.get(key)
+        if not power_level_event:
+            return default
 
-        if power_level_event:
-            return (
-                power_level_event.content.get("ban", 50),
-                power_level_event.content.get("kick", 50),
-                power_level_event.content.get("redact", 50),
-            )
-        return None, None, None,
+        level = power_level_event.content.get(name, None)
+        if level is not None:
+            return int(level)
+        else:
+            return default
 
     @defer.inlineCallbacks
     def get_user_by_req(self, request):
@@ -363,7 +362,7 @@ class Auth(object):
                 default=[""]
             )[0]
             if user and access_token and ip_addr:
-                yield self.store.insert_client_ip(
+                self.store.insert_client_ip(
                     user=user,
                     access_token=access_token,
                     device_id=user_info["device_id"],
@@ -373,7 +372,10 @@ class Auth(object):
 
             defer.returnValue((user, ClientInfo(device_id, token_id)))
         except KeyError:
-            raise AuthError(403, "Missing access token.")
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
+                errcode=Codes.MISSING_TOKEN
+            )
 
     @defer.inlineCallbacks
     def get_user_by_token(self, token):
@@ -387,21 +389,20 @@ class Auth(object):
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
-        try:
-            ret = yield self.store.get_user_by_token(token)
-            if not ret:
-                raise StoreError(400, "Unknown token")
-            user_info = {
-                "admin": bool(ret.get("admin", False)),
-                "device_id": ret.get("device_id"),
-                "user": UserID.from_string(ret.get("name")),
-                "token_id": ret.get("token_id", None),
-            }
+        ret = yield self.store.get_user_by_token(token)
+        if not ret:
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
+                errcode=Codes.UNKNOWN_TOKEN
+            )
+        user_info = {
+            "admin": bool(ret.get("admin", False)),
+            "device_id": ret.get("device_id"),
+            "user": UserID.from_string(ret.get("name")),
+            "token_id": ret.get("token_id", None),
+        }
 
-            defer.returnValue(user_info)
-        except StoreError:
-            raise AuthError(403, "Unrecognised access token.",
-                            errcode=Codes.UNKNOWN_TOKEN)
+        defer.returnValue(user_info)
 
     @defer.inlineCallbacks
     def get_appservice_by_req(self, request):
@@ -409,19 +410,22 @@ class Auth(object):
             token = request.args["access_token"][0]
             service = yield self.store.get_app_service_by_token(token)
             if not service:
-                raise AuthError(403, "Unrecognised access token.",
-                                errcode=Codes.UNKNOWN_TOKEN)
+                raise AuthError(
+                    self.TOKEN_NOT_FOUND_HTTP_STATUS,
+                    "Unrecognised access token.",
+                    errcode=Codes.UNKNOWN_TOKEN
+                )
             defer.returnValue(service)
         except KeyError:
-            raise AuthError(403, "Missing access token.")
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
+            )
 
     def is_server_admin(self, user):
         return self.store.is_server_admin(user)
 
     @defer.inlineCallbacks
     def add_auth_events(self, builder, context):
-        yield run_on_reactor()
-
         auth_ids = self.compute_auth_events(builder, context.current_state)
 
         auth_events_entries = yield self.store.add_event_hashes(
@@ -486,7 +490,7 @@ class Auth(object):
             send_level = send_level_event.content.get("events", {}).get(
                 event.type
             )
-            if not send_level:
+            if send_level is None:
                 if hasattr(event, "state_key"):
                     send_level = send_level_event.content.get(
                         "state_default", 50
@@ -501,16 +505,7 @@ class Auth(object):
         else:
             send_level = 0
 
-        user_level = self._get_power_level_from_event_state(
-            event,
-            event.user_id,
-            auth_events,
-        )
-
-        if user_level:
-            user_level = int(user_level)
-        else:
-            user_level = 0
+        user_level = self._get_user_power_level(event.user_id, auth_events)
 
         if user_level < send_level:
             raise AuthError(
@@ -542,16 +537,9 @@ class Auth(object):
         return True
 
     def _check_redaction(self, event, auth_events):
-        user_level = self._get_power_level_from_event_state(
-            event,
-            event.user_id,
-            auth_events,
-        )
+        user_level = self._get_user_power_level(event.user_id, auth_events)
 
-        _, _, redact_level = self._get_ops_level_from_event_state(
-            event,
-            auth_events,
-        )
+        redact_level = self._get_named_level(auth_events, "redact", 50)
 
         if user_level < redact_level:
             raise AuthError(
@@ -579,11 +567,7 @@ class Auth(object):
         if not current_state:
             return
 
-        user_level = self._get_power_level_from_event_state(
-            event,
-            event.user_id,
-            auth_events,
-        )
+        user_level = self._get_user_power_level(event.user_id, auth_events)
 
         # Check other levels:
         levels_to_check = [
@@ -592,6 +576,7 @@ class Auth(object):
             ("ban", []),
             ("redact", []),
             ("kick", []),
+            ("invite", []),
         ]
 
         old_list = current_state.content.get("users")