summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth.py64
-rw-r--r--synapse/api/ratelimiting.py14
2 files changed, 47 insertions, 31 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 5029f7c534..69b3392735 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -72,7 +72,7 @@ class Auth(object):
         auth_events = {
             (e.type, e.state_key): e for e in auth_events.values()
         }
-        self.check(event, auth_events=auth_events, do_sig_check=False)
+        self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
 
     def check(self, event, auth_events, do_sig_check=True):
         """ Checks if this event is correctly authed.
@@ -91,11 +91,28 @@ class Auth(object):
             if not hasattr(event, "room_id"):
                 raise AuthError(500, "Event has no room_id: %s" % event)
 
-            sender_domain = get_domain_from_id(event.sender)
+            if do_sig_check:
+                sender_domain = get_domain_from_id(event.sender)
+                event_id_domain = get_domain_from_id(event.event_id)
+
+                is_invite_via_3pid = (
+                    event.type == EventTypes.Member
+                    and event.membership == Membership.INVITE
+                    and "third_party_invite" in event.content
+                )
 
-            # Check the sender's domain has signed the event
-            if do_sig_check and not event.signatures.get(sender_domain):
-                raise AuthError(403, "Event not signed by sending server")
+                # Check the sender's domain has signed the event
+                if not event.signatures.get(sender_domain):
+                    # We allow invites via 3pid to have a sender from a different
+                    # HS, as the sender must match the sender of the original
+                    # 3pid invite. This is checked further down with the
+                    # other dedicated membership checks.
+                    if not is_invite_via_3pid:
+                        raise AuthError(403, "Event not signed by sender's server")
+
+                # Check the event_id's domain has signed the event
+                if not event.signatures.get(event_id_domain):
+                    raise AuthError(403, "Event not signed by sending server")
 
             if auth_events is None:
                 # Oh, we don't know what the state of the room was, so we
@@ -491,6 +508,9 @@ class Auth(object):
         if not invite_event:
             return False
 
+        if invite_event.sender != event.sender:
+            return False
+
         if event.user_id != invite_event.user_id:
             return False
 
@@ -583,10 +603,12 @@ class Auth(object):
         """
         # Can optionally look elsewhere in the request (e.g. headers)
         try:
-            user_id = yield self._get_appservice_user_id(request)
+            user_id, app_service = yield self._get_appservice_user_id(request)
             if user_id:
                 request.authenticated_entity = user_id
-                defer.returnValue(synapse.types.create_requester(user_id))
+                defer.returnValue(
+                    synapse.types.create_requester(user_id, app_service=app_service)
+                )
 
             access_token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@@ -624,7 +646,8 @@ class Auth(object):
             request.authenticated_entity = user.to_string()
 
             defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id))
+                user, token_id, is_guest, device_id, app_service=app_service)
+            )
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -633,20 +656,20 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def _get_appservice_user_id(self, request):
-        app_service = yield self.store.get_app_service_by_token(
+        app_service = self.store.get_app_service_by_token(
             get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
         )
         if app_service is None:
-            defer.returnValue(None)
+            defer.returnValue((None, None))
 
         if "user_id" not in request.args:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         user_id = request.args["user_id"][0]
         if app_service.sender == user_id:
-            defer.returnValue(app_service.sender)
+            defer.returnValue((app_service.sender, app_service))
 
         if not app_service.is_interested_in_user(user_id):
             raise AuthError(
@@ -658,7 +681,7 @@ class Auth(object):
                 403,
                 "Application service has not registered this user"
             )
-        defer.returnValue(user_id)
+        defer.returnValue((user_id, app_service))
 
     @defer.inlineCallbacks
     def get_user_by_access_token(self, token, rights="access"):
@@ -835,13 +858,12 @@ class Auth(object):
         }
         defer.returnValue(user_info)
 
-    @defer.inlineCallbacks
     def get_appservice_by_req(self, request):
         try:
             token = get_access_token_from_request(
                 request, self.TOKEN_NOT_FOUND_HTTP_STATUS
             )
-            service = yield self.store.get_app_service_by_token(token)
+            service = self.store.get_app_service_by_token(token)
             if not service:
                 logger.warn("Unrecognised appservice access token: %s" % (token,))
                 raise AuthError(
@@ -850,7 +872,7 @@ class Auth(object):
                     errcode=Codes.UNKNOWN_TOKEN
                 )
             request.authenticated_entity = service.sender
-            defer.returnValue(service)
+            return defer.succeed(service)
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
@@ -982,16 +1004,6 @@ class Auth(object):
                         403,
                         "You are not allowed to set others state"
                     )
-                else:
-                    sender_domain = UserID.from_string(
-                        event.user_id
-                    ).domain
-
-                    if sender_domain != event.state_key:
-                        raise AuthError(
-                            403,
-                            "You are not allowed to set others state"
-                        )
 
         return True
 
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 660dfb56e5..06cc8d90b8 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -23,7 +23,7 @@ class Ratelimiter(object):
     def __init__(self):
         self.message_counts = collections.OrderedDict()
 
-    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
+    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
         """Can the user send a message?
         Args:
             user_id: The user sending a message.
@@ -32,12 +32,15 @@ class Ratelimiter(object):
                 second.
             burst_count: How many messages the user can send before being
                 limited.
+            update (bool): Whether to update the message rates or not. This is
+                useful to check if a message would be allowed to be sent before
+                its ready to be actually sent.
         Returns:
             A pair of a bool indicating if they can send a message now and a
                 time in seconds of when they can next send a message.
         """
         self.prune_message_counts(time_now_s)
-        message_count, time_start, _ignored = self.message_counts.pop(
+        message_count, time_start, _ignored = self.message_counts.get(
             user_id, (0., time_now_s, None),
         )
         time_delta = time_now_s - time_start
@@ -52,9 +55,10 @@ class Ratelimiter(object):
             allowed = True
             message_count += 1
 
-        self.message_counts[user_id] = (
-            message_count, time_start, msg_rate_hz
-        )
+        if update:
+            self.message_counts[user_id] = (
+                message_count, time_start, msg_rate_hz
+            )
 
         if msg_rate_hz > 0:
             time_allowed = (