diff options
-rw-r--r-- | synapse/api/auth.py | 20 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 35 |
2 files changed, 41 insertions, 14 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 597631a88f..f26e585623 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -279,22 +279,14 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - curr_state_id = yield self.state.get_current_state_ids(room_id) + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - for (etype, state_key), event_id in curr_state_id.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - event = yield self.store.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + group, curr_state_ids = yield self.state.resolve_state_groups( + room_id, latest_event_ids + ) - defer.returnValue(False) + ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) + defer.returnValue(ret) def check_event_sender_in_room(self, event, auth_events): key = (EventTypes.Member, event.user_id, ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 2cab065bca..5ce5e8da37 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -400,3 +400,38 @@ class RoomMemberStore(SQLBaseStore): ) defer.returnValue(set(row["user_id"] for row in rows)) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) |