summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth.py80
-rw-r--r--synapse/api/ratelimiting.py14
2 files changed, 59 insertions, 35 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e75fd518be..69b3392735 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -603,10 +603,12 @@ class Auth(object):
         """
         # Can optionally look elsewhere in the request (e.g. headers)
         try:
-            user_id = yield self._get_appservice_user_id(request)
+            user_id, app_service = yield self._get_appservice_user_id(request)
             if user_id:
                 request.authenticated_entity = user_id
-                defer.returnValue(synapse.types.create_requester(user_id))
+                defer.returnValue(
+                    synapse.types.create_requester(user_id, app_service=app_service)
+                )
 
             access_token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@@ -644,7 +646,8 @@ class Auth(object):
             request.authenticated_entity = user.to_string()
 
             defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id))
+                user, token_id, is_guest, device_id, app_service=app_service)
+            )
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -653,20 +656,20 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def _get_appservice_user_id(self, request):
-        app_service = yield self.store.get_app_service_by_token(
+        app_service = self.store.get_app_service_by_token(
             get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
         )
         if app_service is None:
-            defer.returnValue(None)
+            defer.returnValue((None, None))
 
         if "user_id" not in request.args:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         user_id = request.args["user_id"][0]
         if app_service.sender == user_id:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         if not app_service.is_interested_in_user(user_id):
             raise AuthError(
@@ -678,7 +681,7 @@ class Auth(object):
                 403,
                 "Application service has not registered this user"
             )
-        defer.returnValue(user_id)
+        defer.returnValue((user_id, app_service))
 
     @defer.inlineCallbacks
     def get_user_by_access_token(self, token, rights="access"):
@@ -855,13 +858,12 @@ class Auth(object):
         }
         defer.returnValue(user_info)
 
-    @defer.inlineCallbacks
     def get_appservice_by_req(self, request):
         try:
             token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
-            service = yield self.store.get_app_service_by_token(token)
+            service = self.store.get_app_service_by_token(token)
             if not service:
                 logger.warn("Unrecognised appservice access token: %s" % (token,))
                 raise AuthError(
@@ -870,7 +872,7 @@ class Auth(object):
                     errcode=Codes.UNKNOWN_TOKEN
                 )
             request.authenticated_entity = service.sender
-            defer.returnValue(service)
+            return defer.succeed(service)
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
@@ -1002,16 +1004,6 @@ class Auth(object):
                         403,
                         "You are not allowed to set others state"
                     )
-                else:
-                    sender_domain = UserID.from_string(
-                        event.user_id
-                    ).domain
-
-                    if sender_domain != event.state_key:
-                        raise AuthError(
-                            403,
-                            "You are not allowed to set others state"
-                        )
 
         return True
 
@@ -1178,7 +1170,8 @@ def has_access_token(request):
         bool: False if no access_token was given, True otherwise.
     """
     query_params = request.args.get("access_token")
-    return bool(query_params)
+    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
+    return bool(query_params) or bool(auth_headers)
 
 
 def get_access_token_from_request(request, token_not_found_http_status=401):
@@ -1196,13 +1189,40 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
     Raises:
         AuthError: If there isn't an access_token in the request.
     """
+
+    auth_headers = request.requestHeaders.getRawHeaders("Authorization")
     query_params = request.args.get("access_token")
-    # Try to get the access_token from the query params.
-    if not query_params:
-        raise AuthError(
-            token_not_found_http_status,
-            "Missing access token.",
-            errcode=Codes.MISSING_TOKEN
-        )
+    if auth_headers:
+        # Try the get the access_token from a "Authorization: Bearer"
+        # header
+        if query_params is not None:
+            raise AuthError(
+                token_not_found_http_status,
+                "Mixing Authorization headers and access_token query parameters.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+        if len(auth_headers) > 1:
+            raise AuthError(
+                token_not_found_http_status,
+                "Too many Authorization headers.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+        parts = auth_headers[0].split(" ")
+        if parts[0] == "Bearer" and len(parts) == 2:
+            return parts[1]
+        else:
+            raise AuthError(
+                token_not_found_http_status,
+                "Invalid Authorization header.",
+                errcode=Codes.MISSING_TOKEN,
+            )
+    else:
+        # Try to get the access_token from the query params.
+        if not query_params:
+            raise AuthError(
+                token_not_found_http_status,
+                "Missing access token.",
+                errcode=Codes.MISSING_TOKEN
+            )
 
-    return query_params[0]
+        return query_params[0]
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 660dfb56e5..06cc8d90b8 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -23,7 +23,7 @@ class Ratelimiter(object):
     def __init__(self):
         self.message_counts = collections.OrderedDict()
 
-    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
+    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
         """Can the user send a message?
         Args:
             user_id: The user sending a message.
@@ -32,12 +32,15 @@ class Ratelimiter(object):
                 second.
             burst_count: How many messages the user can send before being
                 limited.
+            update (bool): Whether to update the message rates or not. This is
+                useful to check if a message would be allowed to be sent before
+                its ready to be actually sent.
         Returns:
             A pair of a bool indicating if they can send a message now and a
                 time in seconds of when they can next send a message.
         """
         self.prune_message_counts(time_now_s)
-        message_count, time_start, _ignored = self.message_counts.pop(
+        message_count, time_start, _ignored = self.message_counts.get(
             user_id, (0., time_now_s, None),
         )
         time_delta = time_now_s - time_start
@@ -52,9 +55,10 @@ class Ratelimiter(object):
             allowed = True
             message_count += 1
 
-        self.message_counts[user_id] = (
-            message_count, time_start, msg_rate_hz
-        )
+        if update:
+            self.message_counts[user_id] = (
+                message_count, time_start, msg_rate_hz
+            )
 
         if msg_rate_hz > 0:
             time_allowed = (