From a3dc1e9cbe491aa981b8bbaeb2414b4ec8e5b9ca Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 17:32:22 +0100 Subject: Replace context.current_state with context.current_state_ids --- synapse/api/auth.py | 68 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 44 insertions(+), 24 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0db26fcfd7..40c3e9db0d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -52,7 +52,7 @@ class Auth(object): self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 # Docs for these currently lives at - # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst + # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # In addition, we have type == delete_pusher which grants access only to # delete pushers. self._KNOWN_CAVEAT_PREFIXES = set([ @@ -63,6 +63,17 @@ class Auth(object): "user_id = ", ]) + @defer.inlineCallbacks + def check_from_context(self, event, context, do_sig_check=True): + auth_events_ids = yield self.compute_auth_events( + event, context.current_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + self.check(event, auth_events=auth_events, do_sig_check=False) + def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. @@ -847,7 +858,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = self.compute_auth_events(builder, context.current_state) + auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -855,30 +866,32 @@ class Auth(object): builder.auth_events = auth_events_entries - def compute_auth_events(self, event, current_state): + @defer.inlineCallbacks + def compute_auth_events(self, event, current_state_ids, for_verification=False): if event.type == EventTypes.Create: - return [] + defer.returnValue([]) auth_ids = [] key = (EventTypes.PowerLevels, "", ) - power_level_event = current_state.get(key) + power_level_event_id = current_state_ids.get(key) - if power_level_event: - auth_ids.append(power_level_event.event_id) + if power_level_event_id: + auth_ids.append(power_level_event_id) key = (EventTypes.JoinRules, "", ) - join_rule_event = current_state.get(key) + join_rule_event_id = current_state_ids.get(key) key = (EventTypes.Member, event.user_id, ) - member_event = current_state.get(key) + member_event_id = current_state_ids.get(key) key = (EventTypes.Create, "", ) - create_event = current_state.get(key) - if create_event: - auth_ids.append(create_event.event_id) + create_event_id = current_state_ids.get(key) + if create_event_id: + auth_ids.append(create_event_id) - if join_rule_event: + if join_rule_event_id: + join_rule_event = yield self.store.get_event(join_rule_event_id) join_rule = join_rule_event.content.get("join_rule") is_public = join_rule == JoinRules.PUBLIC if join_rule else False else: @@ -887,15 +900,21 @@ class Auth(object): if event.type == EventTypes.Member: e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event: - auth_ids.append(join_rule_event.event_id) + if join_rule_event_id: + auth_ids.append(join_rule_event_id) if e_type == Membership.JOIN: - if member_event and not is_public: - auth_ids.append(member_event.event_id) + if member_event_id and not is_public: + auth_ids.append(member_event_id) else: - if member_event: - auth_ids.append(member_event.event_id) + if member_event_id: + auth_ids.append(member_event_id) + + if for_verification: + key = (EventTypes.Member, event.state_key, ) + existing_event_id = current_state_ids.get(key) + if existing_event_id: + auth_ids.append(existing_event_id) if e_type == Membership.INVITE: if "third_party_invite" in event.content: @@ -903,14 +922,15 @@ class Auth(object): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - third_party_invite = current_state.get(key) - if third_party_invite: - auth_ids.append(third_party_invite.event_id) - elif member_event: + third_party_invite_id = current_state_ids.get(key) + if third_party_invite_id: + auth_ids.append(third_party_invite_id) + elif member_event_id: + member_event = yield self.store.get_event(member_event_id) if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - return auth_ids + defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): key = (EventTypes.PowerLevels, "", ) -- cgit 1.5.1 From 0e1900d8193a612d6920a9eca0aec4813e17d355 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 18:15:51 +0100 Subject: Pull out full state less --- synapse/api/auth.py | 13 +++++++------ synapse/state.py | 12 ++++++++---- 2 files changed, 15 insertions(+), 10 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 40c3e9db0d..23a928de16 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -278,18 +278,19 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): - curr_state = yield self.state.get_current_state(room_id) + curr_state_id = yield self.state.get_current_state_ids(room_id) - for event in curr_state.values(): - if event.type == EventTypes.Member: + for (etype, state_key), event_id in curr_state_id.items(): + if etype == EventTypes.Member: try: - if get_domain_from_id(event.state_key) != host: + if get_domain_from_id(state_key) != host: continue except: - logger.warn("state_key not user_id: %s", event.state_key) + logger.warn("state_key not user_id: %s", state_key) continue - if event.content["membership"] == Membership.JOIN: + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: defer.returnValue(True) defer.returnValue(False) diff --git a/synapse/state.py b/synapse/state.py index 2a01887a67..78461215ca 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -95,15 +95,19 @@ class StateHandler(object): _, state = yield self.resolve_state_groups(room_id, latest_event_ids) + if event_type: + event_id = state.get((event_type, state_key)) + event = None + if event_id: + event = yield self.store.get_event(event_id, allow_none=True) + defer.returnValue(event) + return + state_map = yield self.store.get_events(state.values(), get_prev_content=False) state = { key: state_map[e_id] for key, e_id in state.items() if e_id in state_map } - if event_type: - defer.returnValue(state.get((event_type, state_key))) - return - defer.returnValue(state) @defer.inlineCallbacks -- cgit 1.5.1 From 25414b44a2edbf7cc66e46968e81b56ae32f2887 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Aug 2016 10:47:00 +0100 Subject: Add measure on check_host_in_room --- synapse/api/auth.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 23a928de16..597631a88f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -278,20 +278,21 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): - curr_state_id = yield self.state.get_current_state_ids(room_id) + with Measure(self.clock, "check_host_in_room"): + curr_state_id = yield self.state.get_current_state_ids(room_id) - for (etype, state_key), event_id in curr_state_id.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: + for (etype, state_key), event_id in curr_state_id.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - event = yield self.store.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) defer.returnValue(False) -- cgit 1.5.1 From 1ccdc1e93a5ae854fa89751a78c9103940a9f9e6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Aug 2016 10:59:40 +0100 Subject: Cache check_host_in_room --- synapse/api/auth.py | 20 ++++++-------------- synapse/storage/roommember.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 14 deletions(-) (limited to 'synapse/api') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 597631a88f..f26e585623 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -279,22 +279,14 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - curr_state_id = yield self.state.get_current_state_ids(room_id) + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - for (etype, state_key), event_id in curr_state_id.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - event = yield self.store.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + group, curr_state_ids = yield self.state.resolve_state_groups( + room_id, latest_event_ids + ) - defer.returnValue(False) + ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) + defer.returnValue(ret) def check_event_sender_in_room(self, event, auth_events): key = (EventTypes.Member, event.user_id, ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 2cab065bca..5ce5e8da37 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -400,3 +400,38 @@ class RoomMemberStore(SQLBaseStore): ) defer.returnValue(set(row["user_id"] for row in rows)) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) -- cgit 1.5.1