From a3dc1e9cbe491aa981b8bbaeb2414b4ec8e5b9ca Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 17:32:22 +0100 Subject: Replace context.current_state with context.current_state_ids --- synapse/api/auth.py | 68 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 24 deletions(-) (limited to 'synapse/api/auth.py') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0db26fcfd7..40c3e9db0d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -52,7 +52,7 @@ class Auth(object): self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 # Docs for these currently lives at - # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst + # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # In addition, we have type == delete_pusher which grants access only to # delete pushers. self._KNOWN_CAVEAT_PREFIXES = set([ @@ -63,6 +63,17 @@ class Auth(object): "user_id = ", ]) + @defer.inlineCallbacks + def check_from_context(self, event, context, do_sig_check=True): + auth_events_ids = yield self.compute_auth_events( + event, context.current_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + self.check(event, auth_events=auth_events, do_sig_check=False) + def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. @@ -847,7 +858,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = self.compute_auth_events(builder, context.current_state) + auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -855,30 +866,32 @@ class Auth(object): builder.auth_events = auth_events_entries - def compute_auth_events(self, event, current_state): + @defer.inlineCallbacks + def compute_auth_events(self, event, current_state_ids, for_verification=False): if event.type == EventTypes.Create: - return [] + defer.returnValue([]) auth_ids = [] key = (EventTypes.PowerLevels, "", ) - power_level_event = current_state.get(key) + power_level_event_id = current_state_ids.get(key) - if power_level_event: - auth_ids.append(power_level_event.event_id) + if power_level_event_id: + auth_ids.append(power_level_event_id) key = (EventTypes.JoinRules, "", ) - join_rule_event = current_state.get(key) + join_rule_event_id = current_state_ids.get(key) key = (EventTypes.Member, event.user_id, ) - member_event = current_state.get(key) + member_event_id = current_state_ids.get(key) key = (EventTypes.Create, "", ) - create_event = current_state.get(key) - if create_event: - auth_ids.append(create_event.event_id) + create_event_id = current_state_ids.get(key) + if create_event_id: + auth_ids.append(create_event_id) - if join_rule_event: + if join_rule_event_id: + join_rule_event = yield self.store.get_event(join_rule_event_id) join_rule = join_rule_event.content.get("join_rule") is_public = join_rule == JoinRules.PUBLIC if join_rule else False else: @@ -887,15 +900,21 @@ class Auth(object): if event.type == EventTypes.Member: e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event: - auth_ids.append(join_rule_event.event_id) + if join_rule_event_id: + auth_ids.append(join_rule_event_id) if e_type == Membership.JOIN: - if member_event and not is_public: - auth_ids.append(member_event.event_id) + if member_event_id and not is_public: + auth_ids.append(member_event_id) else: - if member_event: - auth_ids.append(member_event.event_id) + if member_event_id: + auth_ids.append(member_event_id) + + if for_verification: + key = (EventTypes.Member, event.state_key, ) + existing_event_id = current_state_ids.get(key) + if existing_event_id: + auth_ids.append(existing_event_id) if e_type == Membership.INVITE: if "third_party_invite" in event.content: @@ -903,14 +922,15 @@ class Auth(object): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - third_party_invite = current_state.get(key) - if third_party_invite: - auth_ids.append(third_party_invite.event_id) - elif member_event: + third_party_invite_id = current_state_ids.get(key) + if third_party_invite_id: + auth_ids.append(third_party_invite_id) + elif member_event_id: + member_event = yield self.store.get_event(member_event_id) if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - return auth_ids + defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): key = (EventTypes.PowerLevels, "", ) -- cgit 1.4.1