summary refs log tree commit diff
path: root/synapse/api/auth.py
diff options
context:
space:
mode:
authorMark Haines <mark.haines@matrix.org>2016-08-26 14:35:31 +0100
committerMark Haines <mark.haines@matrix.org>2016-08-26 14:35:31 +0100
commit4bbef62124f0fb249e314e94a1b9c15204c8daa9 (patch)
treecc2e7458da160915f7224c0487f185e2433eaadb /synapse/api/auth.py
parentMore 0_0 in tests (diff)
parentMerge pull request #1048 from matrix-org/erikj/fix_mail_name (diff)
downloadsynapse-4bbef62124f0fb249e314e94a1b9c15204c8daa9.tar.xz
Merge remote-tracking branch 'origin/develop' into markjh/direct_to_device
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py88
1 files changed, 51 insertions, 37 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 0db26fcfd7..f26e585623 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -52,7 +52,7 @@ class Auth(object):
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
         # Docs for these currently lives at
-        # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
+        # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
         # In addition, we have type == delete_pusher which grants access only to
         # delete pushers.
         self._KNOWN_CAVEAT_PREFIXES = set([
@@ -63,6 +63,17 @@ class Auth(object):
             "user_id = ",
         ])
 
+    @defer.inlineCallbacks
+    def check_from_context(self, event, context, do_sig_check=True):
+        auth_events_ids = yield self.compute_auth_events(
+            event, context.current_state_ids, for_verification=True,
+        )
+        auth_events = yield self.store.get_events(auth_events_ids)
+        auth_events = {
+            (e.type, e.state_key): e for e in auth_events.values()
+        }
+        self.check(event, auth_events=auth_events, do_sig_check=False)
+
     def check(self, event, auth_events, do_sig_check=True):
         """ Checks if this event is correctly authed.
 
@@ -267,21 +278,15 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def check_host_in_room(self, room_id, host):
-        curr_state = yield self.state.get_current_state(room_id)
-
-        for event in curr_state.values():
-            if event.type == EventTypes.Member:
-                try:
-                    if get_domain_from_id(event.state_key) != host:
-                        continue
-                except:
-                    logger.warn("state_key not user_id: %s", event.state_key)
-                    continue
+        with Measure(self.clock, "check_host_in_room"):
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-                if event.content["membership"] == Membership.JOIN:
-                    defer.returnValue(True)
+            group, curr_state_ids = yield self.state.resolve_state_groups(
+                room_id, latest_event_ids
+            )
 
-        defer.returnValue(False)
+            ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids)
+            defer.returnValue(ret)
 
     def check_event_sender_in_room(self, event, auth_events):
         key = (EventTypes.Member, event.user_id, )
@@ -847,7 +852,7 @@ class Auth(object):
 
     @defer.inlineCallbacks
     def add_auth_events(self, builder, context):
-        auth_ids = self.compute_auth_events(builder, context.current_state)
+        auth_ids = yield self.compute_auth_events(builder, context.current_state_ids)
 
         auth_events_entries = yield self.store.add_event_hashes(
             auth_ids
@@ -855,30 +860,32 @@ class Auth(object):
 
         builder.auth_events = auth_events_entries
 
-    def compute_auth_events(self, event, current_state):
+    @defer.inlineCallbacks
+    def compute_auth_events(self, event, current_state_ids, for_verification=False):
         if event.type == EventTypes.Create:
-            return []
+            defer.returnValue([])
 
         auth_ids = []
 
         key = (EventTypes.PowerLevels, "", )
-        power_level_event = current_state.get(key)
+        power_level_event_id = current_state_ids.get(key)
 
-        if power_level_event:
-            auth_ids.append(power_level_event.event_id)
+        if power_level_event_id:
+            auth_ids.append(power_level_event_id)
 
         key = (EventTypes.JoinRules, "", )
-        join_rule_event = current_state.get(key)
+        join_rule_event_id = current_state_ids.get(key)
 
         key = (EventTypes.Member, event.user_id, )
-        member_event = current_state.get(key)
+        member_event_id = current_state_ids.get(key)
 
         key = (EventTypes.Create, "", )
-        create_event = current_state.get(key)
-        if create_event:
-            auth_ids.append(create_event.event_id)
+        create_event_id = current_state_ids.get(key)
+        if create_event_id:
+            auth_ids.append(create_event_id)
 
-        if join_rule_event:
+        if join_rule_event_id:
+            join_rule_event = yield self.store.get_event(join_rule_event_id)
             join_rule = join_rule_event.content.get("join_rule")
             is_public = join_rule == JoinRules.PUBLIC if join_rule else False
         else:
@@ -887,15 +894,21 @@ class Auth(object):
         if event.type == EventTypes.Member:
             e_type = event.content["membership"]
             if e_type in [Membership.JOIN, Membership.INVITE]:
-                if join_rule_event:
-                    auth_ids.append(join_rule_event.event_id)
+                if join_rule_event_id:
+                    auth_ids.append(join_rule_event_id)
 
             if e_type == Membership.JOIN:
-                if member_event and not is_public:
-                    auth_ids.append(member_event.event_id)
+                if member_event_id and not is_public:
+                    auth_ids.append(member_event_id)
             else:
-                if member_event:
-                    auth_ids.append(member_event.event_id)
+                if member_event_id:
+                    auth_ids.append(member_event_id)
+
+                if for_verification:
+                    key = (EventTypes.Member, event.state_key, )
+                    existing_event_id = current_state_ids.get(key)
+                    if existing_event_id:
+                        auth_ids.append(existing_event_id)
 
             if e_type == Membership.INVITE:
                 if "third_party_invite" in event.content:
@@ -903,14 +916,15 @@ class Auth(object):
                         EventTypes.ThirdPartyInvite,
                         event.content["third_party_invite"]["signed"]["token"]
                     )
-                    third_party_invite = current_state.get(key)
-                    if third_party_invite:
-                        auth_ids.append(third_party_invite.event_id)
-        elif member_event:
+                    third_party_invite_id = current_state_ids.get(key)
+                    if third_party_invite_id:
+                        auth_ids.append(third_party_invite_id)
+        elif member_event_id:
+            member_event = yield self.store.get_event(member_event_id)
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
 
-        return auth_ids
+        defer.returnValue(auth_ids)
 
     def _get_send_level(self, etype, state_key, auth_events):
         key = (EventTypes.PowerLevels, "", )