From 17f4f14df7712426ffe0ddc3dc460820745de8a2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 13:28:31 +0100 Subject: Pull out event ids rather than full events for state --- synapse/storage/state.py | 55 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 17 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 0e8fa93e1f..fa40af6933 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -44,11 +44,7 @@ class StateStore(SQLBaseStore): """ @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - The return value is a dict mapping group names to lists of events. - """ + def get_state_groups_ids(self, room_id, event_ids): if not event_ids: defer.returnValue({}) @@ -59,9 +55,32 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups) + defer.returnValue(group_to_state) + + @defer.inlineCallbacks + def get_state_groups(self, room_id, event_ids): + """ Get the state groups for the given list of event_ids + + The return value is a dict mapping group names to lists of events. + """ + if not event_ids: + defer.returnValue({}) + + group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + + state_event_map = yield self.get_events( + [ + ev_id for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False + ) + defer.returnValue({ - group: state_map.values() - for group, state_map in group_to_state.items() + group: [ + state_event_map[v] for v in event_id_map.values() if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() }) def _store_mult_state_groups_txn(self, txn, events_and_contexts): @@ -248,8 +267,17 @@ class StateStore(SQLBaseStore): groups = set(event_to_groups.values()) group_to_state = yield self._get_state_for_groups(groups, types) + state_event_map = yield self.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False + ) + event_to_state = { - event_id: group_to_state[group] + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } for event_id, group in event_to_groups.items() } @@ -428,20 +456,13 @@ class StateStore(SQLBaseStore): full=(types is None), ) - state_events = yield self._get_events( - [ev_id for sd in results.values() for ev_id in sd.values()], - get_prev_content=False - ) - - state_events = {e.event_id: e for e in state_events} - # Remove all the entries with None values. The None values were just # used for bookkeeping in the cache. for group, state_dict in results.items(): results[group] = { - key: state_events[event_id] + key: event_id for key, event_id in state_dict.items() - if event_id and event_id in state_events + if event_id } defer.returnValue(results) -- cgit 1.4.1 From a3dc1e9cbe491aa981b8bbaeb2414b4ec8e5b9ca Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 17:32:22 +0100 Subject: Replace context.current_state with context.current_state_ids --- synapse/api/auth.py | 68 +++++++++----- synapse/events/snapshot.py | 13 +-- synapse/handlers/_base.py | 30 ++---- synapse/handlers/federation.py | 112 ++++++++++++---------- synapse/handlers/message.py | 91 +++++++++++++----- synapse/handlers/room_member.py | 124 +++++++++++++++++-------- synapse/push/action_generator.py | 4 +- synapse/push/bulk_push_rule_evaluator.py | 32 +++---- synapse/state.py | 48 +++++----- synapse/storage/push_rule.py | 21 +++-- synapse/storage/roommember.py | 45 ++++++++- synapse/storage/state.py | 16 ++-- synapse/visibility.py | 19 ++++ tests/replication/slave/storage/test_events.py | 9 +- tests/test_state.py | 73 ++++++--------- 15 files changed, 435 insertions(+), 270 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0db26fcfd7..40c3e9db0d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -52,7 +52,7 @@ class Auth(object): self.state = hs.get_state_handler() self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 # Docs for these currently lives at - # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst + # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # In addition, we have type == delete_pusher which grants access only to # delete pushers. self._KNOWN_CAVEAT_PREFIXES = set([ @@ -63,6 +63,17 @@ class Auth(object): "user_id = ", ]) + @defer.inlineCallbacks + def check_from_context(self, event, context, do_sig_check=True): + auth_events_ids = yield self.compute_auth_events( + event, context.current_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + self.check(event, auth_events=auth_events, do_sig_check=False) + def check(self, event, auth_events, do_sig_check=True): """ Checks if this event is correctly authed. @@ -847,7 +858,7 @@ class Auth(object): @defer.inlineCallbacks def add_auth_events(self, builder, context): - auth_ids = self.compute_auth_events(builder, context.current_state) + auth_ids = yield self.compute_auth_events(builder, context.current_state_ids) auth_events_entries = yield self.store.add_event_hashes( auth_ids @@ -855,30 +866,32 @@ class Auth(object): builder.auth_events = auth_events_entries - def compute_auth_events(self, event, current_state): + @defer.inlineCallbacks + def compute_auth_events(self, event, current_state_ids, for_verification=False): if event.type == EventTypes.Create: - return [] + defer.returnValue([]) auth_ids = [] key = (EventTypes.PowerLevels, "", ) - power_level_event = current_state.get(key) + power_level_event_id = current_state_ids.get(key) - if power_level_event: - auth_ids.append(power_level_event.event_id) + if power_level_event_id: + auth_ids.append(power_level_event_id) key = (EventTypes.JoinRules, "", ) - join_rule_event = current_state.get(key) + join_rule_event_id = current_state_ids.get(key) key = (EventTypes.Member, event.user_id, ) - member_event = current_state.get(key) + member_event_id = current_state_ids.get(key) key = (EventTypes.Create, "", ) - create_event = current_state.get(key) - if create_event: - auth_ids.append(create_event.event_id) + create_event_id = current_state_ids.get(key) + if create_event_id: + auth_ids.append(create_event_id) - if join_rule_event: + if join_rule_event_id: + join_rule_event = yield self.store.get_event(join_rule_event_id) join_rule = join_rule_event.content.get("join_rule") is_public = join_rule == JoinRules.PUBLIC if join_rule else False else: @@ -887,15 +900,21 @@ class Auth(object): if event.type == EventTypes.Member: e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event: - auth_ids.append(join_rule_event.event_id) + if join_rule_event_id: + auth_ids.append(join_rule_event_id) if e_type == Membership.JOIN: - if member_event and not is_public: - auth_ids.append(member_event.event_id) + if member_event_id and not is_public: + auth_ids.append(member_event_id) else: - if member_event: - auth_ids.append(member_event.event_id) + if member_event_id: + auth_ids.append(member_event_id) + + if for_verification: + key = (EventTypes.Member, event.state_key, ) + existing_event_id = current_state_ids.get(key) + if existing_event_id: + auth_ids.append(existing_event_id) if e_type == Membership.INVITE: if "third_party_invite" in event.content: @@ -903,14 +922,15 @@ class Auth(object): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - third_party_invite = current_state.get(key) - if third_party_invite: - auth_ids.append(third_party_invite.event_id) - elif member_event: + third_party_invite_id = current_state_ids.get(key) + if third_party_invite_id: + auth_ids.append(third_party_invite_id) + elif member_event_id: + member_event = yield self.store.get_event(member_event_id) if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - return auth_ids + defer.returnValue(auth_ids) def _get_send_level(self, etype, state_key, auth_events): key = (EventTypes.PowerLevels, "", ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index cf11b4aa2e..c75afd02d8 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,17 +15,8 @@ class EventContext(object): - def _set_current_state(self, current_state): - if current_state is not None: - self.current_state_ids = {k: e.event_id for k, e in current_state.items()} - else: - self.current_state_ids = None - self._current_state = current_state - - current_state = property(lambda self: self._current_state, _set_current_state) - - def __init__(self, current_state=None): - self.current_state = current_state + def __init__(self, current_state_ids=None): + self.current_state_ids = current_state_ids self.state_group = None self.rejected = False self.push_actions = [] diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 11081a0cd5..e58735294e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -65,33 +65,21 @@ class BaseHandler(object): retry_after_ms=int(1000 * (time_allowed - time_now)), ) - def is_host_in_room(self, current_state): - room_members = [ - (state_key, event.membership) - for ((event_type, state_key), event) in current_state.items() - if event_type == EventTypes.Member - ] - if len(room_members) == 0: - # Have we just created the room, and is this about to be the very - # first member event? - create_event = current_state.get(("m.room.create", "")) - if create_event: - return True - for (state_key, membership) in room_members: - if ( - self.hs.is_mine_id(state_key) - and membership == Membership.JOIN - ): - return True - return False - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, current_state): + def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": + if context: + current_state = yield self.store.get_events( + context.current_state_ids.values() + ) + current_state = current_state.values() + else: + current_state = yield self.store.get_current_state(event.room_id) + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) @defer.inlineCallbacks diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 92679532b9..2b88e6550e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -217,11 +217,21 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - prev_state = context.current_state.get((event.type, event.state_key)) - if not prev_state or prev_state.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally - # joined the room. Don't bother if the user is just - # changing their profile info. + # Only fire user_joined_room if the user has acutally + # joined the room. Don't bother if the user is just + # changing their profile info. + newly_joined = True + prev_state_id = context.current_state_ids.get( + (event.type, event.state_key) + ) + if prev_state_id: + prev_state = yield self.store.get_event( + prev_state_id, allow_none=True, + ) + if prev_state and prev_state.membership == Membership.JOIN: + newly_joined = False + + if newly_joined: user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) @@ -734,7 +744,7 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - self.auth.check(event, auth_events=context.current_state, do_sig_check=False) + yield self.auth.check_from_context(event, context, do_sig_check=False) defer.returnValue(event) @@ -782,18 +792,11 @@ class FederationHandler(BaseHandler): new_pdu = event - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(s.state_key)) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - + message_handler = self.hs.get_handlers().message_handler + destinations = yield message_handler.get_joined_hosts_for_room_from_state( + context + ) + destinations = set(destinations) destinations.discard(origin) logger.debug( @@ -804,13 +807,15 @@ class FederationHandler(BaseHandler): self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = [e.event_id for e in context.current_state.values()] + state_ids = context.current_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) + state = yield self.store.get_events(context.current_state_ids.values()) + defer.returnValue({ - "state": context.current_state.values(), + "state": state.values(), "auth_chain": auth_chain, }) @@ -966,7 +971,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - self.auth.check(event, auth_events=context.current_state, do_sig_check=False) + yield self.auth.check_from_context(event, context, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) raise e @@ -1010,18 +1015,11 @@ class FederationHandler(BaseHandler): new_pdu = event - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.LEAVE: - destinations.add(get_domain_from_id(s.state_key)) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - + message_handler = self.hs.get_handlers().message_handler + destinations = yield message_handler.get_joined_hosts_for_room_from_state( + context + ) + destinations = set(destinations) destinations.discard(origin) logger.debug( @@ -1306,7 +1304,13 @@ class FederationHandler(BaseHandler): ) if not auth_events: - auth_events = context.current_state + auth_events_ids = yield self.auth.compute_auth_events( + event, context.current_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -1332,8 +1336,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR if event.type == EventTypes.GuestAccess: - full_context = yield self.store.get_current_state(room_id=event.room_id) - yield self.maybe_kick_guest_users(event, full_context) + yield self.maybe_kick_guest_users(event) defer.returnValue(context) @@ -1504,7 +1507,9 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state.update(auth_events) + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) context.state_group = None if different_auth and not event.internal_metadata.is_outlier(): @@ -1526,8 +1531,8 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. - auth_ids = self.auth.compute_auth_events( - event, context.current_state + auth_ids = yield self.auth.compute_auth_events( + event, context.current_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1583,7 +1588,9 @@ class FederationHandler(BaseHandler): # 4. Look at rejects and their proofs. # TODO. - context.current_state.update(auth_events) + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) context.state_group = None try: @@ -1770,12 +1777,12 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, context.current_state) + yield self.auth.check_from_context(event, context) except AuthError as e: logger.warn("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, auth_events=context.current_state) + yield self._check_signature(event, context) member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event(None, event, context) else: @@ -1801,11 +1808,11 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, auth_events=context.current_state) + self.auth.check_from_context(event, context) except AuthError as e: logger.warn("Denying third party invite %r because %s", event, e) raise e - yield self._check_signature(event, auth_events=context.current_state) + yield self._check_signature(event, context) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. @@ -1819,7 +1826,12 @@ class FederationHandler(BaseHandler): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - original_invite = context.current_state.get(key) + original_invite = None + original_invite_id = context.current_state_ids.get(key) + if original_invite_id: + original_invite = yield self.store.get_event( + original_invite_id, allow_none=True + ) if not original_invite: logger.info( "Could not find invite event for third_party_invite - " @@ -1836,13 +1848,13 @@ class FederationHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def _check_signature(self, event, auth_events): + def _check_signature(self, event, context): """ Checks that the signature in the event is consistent with its invite. Args: event (Event): The m.room.member event to check - auth_events (dict<(event type, state_key), event>): + context (EventContext): Raises: AuthError: if signature didn't match any keys, or key has been @@ -1853,10 +1865,14 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event = auth_events.get( + invite_event_id = context.current_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) + invite_event = None + if invite_event_id: + invite_event = yield self.store.get_event(invite_event_id, allow_none=True) + if not invite_event: raise AuthError(403, "Could not find invite") diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4c3cd9d12e..e2f4387f60 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.metrics import measure_func +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -248,7 +249,7 @@ class MessageHandler(BaseHandler): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = self.deduplicate_state_event(event, context) + prev_state = yield self.deduplicate_state_event(event, context) if prev_state is not None: defer.returnValue(prev_state) @@ -263,6 +264,7 @@ class MessageHandler(BaseHandler): presence = self.hs.get_presence_handler() yield presence.bump_presence_active_time(user) + @defer.inlineCallbacks def deduplicate_state_event(self, event, context): """ Checks whether event is in the latest resolved state in context. @@ -270,13 +272,17 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event = context.current_state.get((event.type, event.state_key)) + prev_event_id = context.current_state_ids.get((event.type, event.state_key)) + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + if not prev_event: + return + if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: - return prev_event - return None + defer.returnValue(prev_event) + return @defer.inlineCallbacks def create_and_send_nonmember_event( @@ -803,7 +809,7 @@ class MessageHandler(BaseHandler): logger.debug( "Created event %s with current state: %s", - event.event_id, context.current_state, + event.event_id, context.current_state_ids, ) defer.returnValue( @@ -826,12 +832,12 @@ class MessageHandler(BaseHandler): self.ratelimit(requester) try: - self.auth.check(event, auth_events=context.current_state) + yield self.auth.check_from_context(event, context) except AuthError as err: logger.warn("Denying new event %r because %s", event, err) raise err - yield self.maybe_kick_guest_users(event, context.current_state.values()) + yield self.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Check the alias is acually valid (at this time at least) @@ -859,6 +865,15 @@ class MessageHandler(BaseHandler): e.sender == event.sender ) + state_to_include_ids = [ + e_id + for k, e_id in context.current_state_ids.items() + if k[0] in self.hs.config.room_invite_state_types + or k[0] == EventTypes.Member and k[1] == event.sender + ] + + state_to_include = yield self.store.get_events(state_to_include_ids) + event.unsigned["invite_room_state"] = [ { "type": e.type, @@ -866,9 +881,7 @@ class MessageHandler(BaseHandler): "content": e.content, "sender": e.sender, } - for k, e in context.current_state.items() - if e.type in self.hs.config.room_invite_state_types - or is_inviter_member_event(e) + for e in state_to_include.values() ] invitee = UserID.from_string(event.state_key) @@ -890,7 +903,14 @@ class MessageHandler(BaseHandler): ) if event.type == EventTypes.Redaction: - if self.auth.check_redaction(event, auth_events=context.current_state): + auth_events_ids = yield self.auth.compute_auth_events( + event, context.current_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + if self.auth.check_redaction(event, auth_events=auth_events): original_event = yield self.store.get_event( event.redacts, check_redacted=False, @@ -904,7 +924,7 @@ class MessageHandler(BaseHandler): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.current_state: + if event.type == EventTypes.Create and context.current_state_ids: raise AuthError( 403, "Changing the room create event is forbidden", @@ -925,16 +945,7 @@ class MessageHandler(BaseHandler): event_stream_id, max_stream_id ) - destinations = set() - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(s.state_key)) - except SynapseError: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) + destinations = yield self.get_joined_hosts_for_room_from_state(context) @defer.inlineCallbacks def _notify(): @@ -952,3 +963,39 @@ class MessageHandler(BaseHandler): preserve_fn(federation_handler.handle_new_event)( event, destinations=destinations, ) + + def get_joined_hosts_for_room_from_state(self, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts_for_room_from_state( + state_group, context.current_state_ids + ) + + @cachedInlineCallbacks(num_args=1, cache_context=True) + def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids, + cache_context): + + # Don't bother getting state for people on the same HS + current_state = yield self.store.get_events([ + e_id for key, e_id in current_state_ids.items() + if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1]) + ]) + + destinations = set() + for e in current_state.itervalues(): + try: + if e.type == EventTypes.Member: + if e.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(e.state_key)) + except SynapseError: + logger.warn( + "Failed to get destination from event %s", e.event_id + ) + + defer.returnValue(destinations) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8b17632fdc..dd4b90ee24 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -93,20 +93,26 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event = context.current_state.get( + prev_member_event_id = context.current_state_ids.get( (EventTypes.Member, target.to_string()), None ) if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: yield user_joined_room(self.distributor, target, room_id) elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target, room_id) + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target, room_id) @defer.inlineCallbacks def remote_join(self, remote_room_hosts, room_id, user, content): @@ -195,29 +201,32 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts = [] latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - current_state = yield self.state_handler.get_current_state( + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) - old_state = current_state.get((EventTypes.Member, target.to_string())) - old_membership = old_state.content.get("membership") if old_state else None - if action == "unban" and old_membership != "ban": - raise SynapseError( - 403, - "Cannot unban user who was not banned (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE - ) - if old_membership == "ban" and action != "unban": - raise SynapseError( - 403, - "Cannot %s user who was banned" % (action,), - errcode=Codes.BAD_STATE - ) + old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + if old_state_id: + old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned" + " (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was banned" % (action,), + errcode=Codes.BAD_STATE + ) - is_host_in_room = self.is_host_in_room(current_state) + is_host_in_room = yield self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(current_state): + if requester.is_guest and not self._can_guest_join(current_state_ids): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") @@ -326,15 +335,17 @@ class RoomMemberHandler(BaseHandler): requester = synapse.types.create_requester(target_user) message_handler = self.hs.get_handlers().message_handler - prev_event = message_handler.deduplicate_state_event(event, context) + prev_event = yield message_handler.deduplicate_state_event(event, context) if prev_event is not None: return if event.membership == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(context.current_state): - # This should be an auth check, but guests are a local concept, - # so don't really fit into the general auth process. - raise AuthError(403, "Guest access not allowed") + if requester.is_guest: + guest_can_join = yield self._can_guest_join(context.current_state_ids) + if not guest_can_join: + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") yield message_handler.handle_new_client_event( requester, @@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event = context.current_state.get( - (EventTypes.Member, target_user.to_string()), + prev_member_event_id = context.current_state_ids.get( + (EventTypes.Member, event.state_key), None ) if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: yield user_joined_room(self.distributor, target_user, room_id) elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) - def _can_guest_join(self, current_state): + @defer.inlineCallbacks + def _can_guest_join(self, current_state_ids): """ Returns whether a guest can join a room based on its current state. """ - guest_access = current_state.get((EventTypes.GuestAccess, ""), None) - return ( + guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + if not guest_access_id: + defer.returnValue(False) + + guest_access = yield self.store.get_event(guest_access_id) + + defer.returnValue( guest_access and guest_access.content and "guest_access" in guest_access.content @@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler): if membership: yield self.store.forget(user_id, room_id) + + @defer.inlineCallbacks + def _is_host_in_room(self, current_state_ids): + # Have we just created the room, and is this about to be the very + # first member event? + create_event_id = current_state_ids.get(("m.room.create", "")) + if len(current_state_ids) == 1 and create_event_id: + defer.returnValue(self.hs.is_mine_id(create_event_id)) + + for (etype, state_key), event_id in current_state_ids.items(): + if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if not event: + continue + + if event.membership == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index ed2ccc4dfb..3f75d3f921 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -40,12 +40,12 @@ class ActionGenerator: def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "evaluator_for_event"): bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context.state_group, context.current_state + event, self.hs, self.store, context ) with Measure(self.clock, "action_for_event_by_user"): actions_by_user = yield bulk_evaluator.action_for_event_by_user( - event, context.current_state + event, context ) context.push_actions = [ diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 004eded61f..8d49beaec5 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,8 @@ from twisted.internet import defer from .push_rule_evaluator import PushRuleEvaluatorForEvent -from synapse.api.constants import EventTypes, Membership -from synapse.visibility import filter_events_for_clients +from synapse.api.constants import EventTypes +from synapse.visibility import filter_events_for_clients_context logger = logging.getLogger(__name__) @@ -36,9 +36,9 @@ def _get_rules(room_id, user_ids, store): @defer.inlineCallbacks -def evaluator_for_event(event, hs, store, state_group, current_state): +def evaluator_for_event(event, hs, store, context): rules_by_user = yield store.bulk_get_push_rules_for_room( - event.room_id, state_group, current_state + event.room_id, context ) # if this event is an invite event, we may need to run rules for the user @@ -72,7 +72,7 @@ class BulkPushRuleEvaluator: self.store = store @defer.inlineCallbacks - def action_for_event_by_user(self, event, current_state): + def action_for_event_by_user(self, event, context): actions_by_user = {} # None of these users can be peeking since this list of users comes @@ -82,27 +82,25 @@ class BulkPushRuleEvaluator: (u, False) for u in self.rules_by_user.keys() ] - filtered_by_user = yield filter_events_for_clients( - self.store, user_tuples, [event], {event.event_id: current_state} + filtered_by_user = yield filter_events_for_clients_context( + self.store, user_tuples, [event], {event.event_id: context} ) - room_members = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN + room_members = yield self.store.get_joined_users_from_context( + event.room_id, context, ) evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) condition_cache = {} - display_names = {} - for ev in current_state.values(): - nm = ev.content.get("displayname", None) - if nm and ev.type == EventTypes.Member: - display_names[ev.state_key] = nm - for uid, rules in self.rules_by_user.items(): - display_name = display_names.get(uid, None) + display_name = None + member_ev_id = context.current_state_ids.get((EventTypes.Member, uid)) + if member_ev_id: + member_ev = yield self.store.get_event(member_ev_id, allow_none=True) + if member_ev: + display_name = member_ev.content.get("displayname", None) filtered = filtered_by_user[uid] if len(filtered) == 0: diff --git a/synapse/state.py b/synapse/state.py index 2249b7fffb..2a01887a67 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -106,6 +106,20 @@ class StateHandler(object): defer.returnValue(state) + @defer.inlineCallbacks + def get_current_state_ids(self, room_id, event_type=None, state_key="", + latest_event_ids=None): + if not latest_event_ids: + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + + _, state = yield self.resolve_state_groups(room_id, latest_event_ids) + + if event_type: + defer.returnValue(state.get((event_type, state_key))) + return + + defer.returnValue(state) + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): """ Fills out the context with the `current state` of the graph. The @@ -127,27 +141,27 @@ class StateHandler(object): # state. Certainly store.get_current_state won't return any, and # persisting the event won't store the state group. if old_state: - context.current_state = { - (s.type, s.state_key): s for s in old_state + context.current_state_ids = { + (s.type, s.state_key): s.event_id for s in old_state } else: - context.current_state = {} + context.current_state_ids = {} context.prev_state_events = [] context.state_group = None defer.returnValue(context) if old_state: - context.current_state = { - (s.type, s.state_key): s for s in old_state + context.current_state_ids = { + (s.type, s.state_key): s.event_id for s in old_state } context.state_group = None if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state: - replaces = context.current_state[key] - if replaces.event_id != event.event_id: # Paranoia check - event.unsigned["replaces_state"] = replaces.event_id + if key in context.current_state_ids: + replaces = context.current_state_ids[key] + if replaces != event.event_id: # Paranoia check + event.unsigned["replaces_state"] = replaces context.prev_state_events = [] defer.returnValue(context) @@ -165,22 +179,14 @@ class StateHandler(object): group, curr_state = ret - state_map = yield self.store.get_events( - curr_state.values(), - get_prev_content=False - ) - curr_state = { - key: state_map[e_id] for key, e_id in curr_state.items() if e_id in state_map - } - - context.current_state = curr_state + context.current_state_ids = curr_state context.state_group = group if not event.is_state() else None if event.is_state(): key = (event.type, event.state_key) - if key in context.current_state: - replaces = context.current_state[key] - event.unsigned["replaces_state"] = replaces.event_id + if key in context.current_state_ids: + replaces = context.current_state_ids[key] + event.unsigned["replaces_state"] = replaces context.prev_state_events = [] defer.returnValue(context) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 78334a98cf..7e6ec411cd 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -124,7 +124,8 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(results) - def bulk_get_push_rules_for_room(self, room_id, state_group, current_state): + def bulk_get_push_rules_for_room(self, room_id, context): + state_group = context.state_group if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -132,10 +133,12 @@ class PushRuleStore(SQLBaseStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._bulk_get_push_rules_for_room(room_id, state_group, current_state) + return self._bulk_get_push_rules_for_room( + room_id, state_group, context.current_state_ids + ) @cachedInlineCallbacks(num_args=2, cache_context=True) - def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state, + def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids, cache_context): # We don't use `state_group`, its there so that we can cache based # on it. However, its important that its never None, since two current_state's @@ -147,10 +150,16 @@ class PushRuleStore(SQLBaseStore): # their unread countss are correct in the event stream, but to avoid # generating them for bot / AS users etc, we only do so for people who've # sent a read receipt into the room. + local_user_member_ids = [ + e_id for (etype, state_key), e_id in current_state_ids.iteritems() + if etype == EventTypes.Member and self.hs.is_mine_id(state_key) + ] + + local_member_events = yield self._get_events(local_user_member_ids) + local_users_in_room = set( - e.state_key for e in current_state.values() - if e.type == EventTypes.Member and e.membership == Membership.JOIN - and self.hs.is_mine_id(e.state_key) + member_event.state_key for member_event in local_member_events + if member_event.membership == Membership.JOIN ) # users in the room who have pushers need to get push rules run because diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index a422ddf633..3ffad672a7 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -20,7 +20,7 @@ from collections import namedtuple from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks -from synapse.api.constants import Membership +from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging @@ -325,7 +325,8 @@ class RoomMemberStore(SQLBaseStore): @cachedInlineCallbacks(num_args=3) def was_forgotten_at(self, user_id, room_id, event_id): - """Returns whether user_id has elected to discard history for room_id at event_id. + """Returns whether user_id has elected to discard history for room_id at + event_id. event_id must be a membership event.""" def f(txn): @@ -358,3 +359,43 @@ class RoomMemberStore(SQLBaseStore): }, desc="who_forgot" ) + + def get_joined_users_from_context(self, room_id, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, context.current_state_ids + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True) + def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, + cache_context): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + member_event_ids = [ + e_id + for key, e_id in current_state_ids.iteritems() + if key[0] == EventTypes.Member + ] + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids, + retcols=['user_id'], + keyvalues={ + "membership": Membership.JOIN, + } + ) + + defer.returnValue(set(row["user_id"] for row in rows)) diff --git a/synapse/storage/state.py b/synapse/storage/state.py index fa40af6933..22f7fb1aa1 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -89,17 +89,17 @@ class StateStore(SQLBaseStore): if event.internal_metadata.is_outlier(): continue - if context.current_state is None: + if context.current_state_ids is None: continue if context.state_group is not None: state_groups[event.event_id] = context.state_group continue - state_events = dict(context.current_state) + state_event_ids = dict(context.current_state_ids) if event.is_state(): - state_events[(event.type, event.state_key)] = event + state_event_ids[(event.type, event.state_key)] = event.event_id state_group = context.new_state_group_id @@ -119,12 +119,12 @@ class StateStore(SQLBaseStore): values=[ { "state_group": state_group, - "room_id": state.room_id, - "type": state.type, - "state_key": state.state_key, - "event_id": state.event_id, + "room_id": event.room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, } - for state in state_events.values() + for key, state_id in state_event_ids.items() ], ) state_groups[event.event_id] = state_group diff --git a/synapse/visibility.py b/synapse/visibility.py index cc12c0a23d..199b16d827 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -180,6 +180,25 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): }) +@defer.inlineCallbacks +def filter_events_for_clients_context(store, user_tuples, events, event_id_to_context): + user_ids = set(u[0] for u in user_tuples) + event_id_to_state = {} + for event_id, context in event_id_to_context.items(): + state = yield store.get_events([ + e_id + for key, e_id in context.current_state_ids.iteritems() + if key == (EventTypes.RoomHistoryVisibility, "") + or (key[0] == EventTypes.Member and key[1] in user_ids) + ]) + event_id_to_state[event_id] = state + + res = yield filter_events_for_clients( + store, user_tuples, events, event_id_to_state + ) + defer.returnValue(res) + + @defer.inlineCallbacks def filter_events_for_client(store, user_id, events, is_peeking=False): """ diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index f33e6f60fb..218cb24889 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -305,7 +305,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.event_id += 1 - context = EventContext(current_state=state) + if state is not None: + state_ids = { + key: e.event_id for key, e in state.items() + } + else: + state_ids = None + + context = EventContext(current_state_ids=state_ids) context.push_actions = push_actions ordering = None diff --git a/tests/test_state.py b/tests/test_state.py index 1a11bbcee0..df9362c985 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -69,7 +69,7 @@ class StateGroupStore(object): self._next_group = 1 - def get_state_groups(self, room_id, event_ids): + def get_state_groups_ids(self, room_id, event_ids): groups = {} for event_id in event_ids: group = self._event_to_state_group.get(event_id) @@ -79,20 +79,20 @@ class StateGroupStore(object): return defer.succeed(groups) def store_state_groups(self, event, context): - if context.current_state is None: + if context.current_state_ids is None: return - state_events = context.current_state + state_events = dict(context.current_state_ids) if event.is_state(): - state_events[(event.type, event.state_key)] = event + state_events[(event.type, event.state_key)] = event.event_id state_group = context.state_group if not state_group: state_group = self._next_group self._next_group += 1 - self._group_to_state[state_group] = state_events.values() + self._group_to_state[state_group] = state_events self._event_to_state_group[event.event_id] = state_group @@ -136,7 +136,7 @@ class StateTestCase(unittest.TestCase): def setUp(self): self.store = Mock( spec_set=[ - "get_state_groups", + "get_state_groups_ids", "add_event_hashes", ] ) @@ -187,7 +187,7 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids context_store = {} @@ -196,7 +196,7 @@ class StateTestCase(unittest.TestCase): store.store_state_groups(event, context) context_store[event.event_id] = context - self.assertEqual(2, len(context_store["D"].current_state)) + self.assertEqual(2, len(context_store["D"].current_state_ids)) @defer.inlineCallbacks def test_branch_basic_conflict(self): @@ -239,7 +239,7 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids context_store = {} @@ -303,7 +303,7 @@ class StateTestCase(unittest.TestCase): ) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids context_store = {} @@ -384,7 +384,7 @@ class StateTestCase(unittest.TestCase): graph = Graph(nodes, edges) store = StateGroupStore() - self.store.get_state_groups.side_effect = store.get_state_groups + self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids context_store = {} @@ -424,13 +424,8 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( - set(old_state), set(context.current_state.values()) + set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) self.assertIsNone(context.state_group) @@ -449,14 +444,8 @@ class StateTestCase(unittest.TestCase): event, old_state=old_state ) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( - set(old_state), - set(context.current_state.values()) + set(e.event_id for e in old_state), set(context.current_state_ids.values()) ) self.assertIsNone(context.state_group) @@ -473,20 +462,15 @@ class StateTestCase(unittest.TestCase): group_name = "group_name_1" - self.store.get_state_groups.return_value = { - group_name: old_state, + self.store.get_state_groups_ids.return_value = { + group_name: {(e.type, e.state_key): e.event_id for e in old_state}, } context = yield self.state.compute_event_context(event) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in context.current_state.values()]) + set(context.current_state_ids.values()) ) self.assertEqual(group_name, context.state_group) @@ -503,20 +487,15 @@ class StateTestCase(unittest.TestCase): group_name = "group_name_1" - self.store.get_state_groups.return_value = { - group_name: old_state, + self.store.get_state_groups_ids.return_value = { + group_name: {(e.type, e.state_key): e.event_id for e in old_state}, } context = yield self.state.compute_event_context(event) - for k, v in context.current_state.items(): - type, state_key = k - self.assertEqual(type, v.type) - self.assertEqual(state_key, v.state_key) - self.assertEqual( set([e.event_id for e in old_state]), - set([e.event_id for e in context.current_state.values()]) + set(context.current_state_ids.values()) ) self.assertIsNone(context.state_group) @@ -545,7 +524,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 6) + self.assertEqual(len(context.current_state_ids), 6) self.assertIsNone(context.state_group) @@ -573,7 +552,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(len(context.current_state), 6) + self.assertEqual(len(context.current_state_ids), 6) self.assertIsNone(context.state_group) @@ -608,7 +587,7 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_2[2], context.current_state[("test1", "1")]) + self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")]) # Reverse the depth to make sure we are actually using the depths # during state resolution. @@ -627,15 +606,15 @@ class StateTestCase(unittest.TestCase): context = yield self._get_context(event, old_state_1, old_state_2) - self.assertEqual(old_state_1[2], context.current_state[("test1", "1")]) + self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")]) def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" group_name_2 = "group_name_2" - self.store.get_state_groups.return_value = { - group_name_1: old_state_1, - group_name_2: old_state_2, + self.store.get_state_groups_ids.return_value = { + group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1}, + group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2}, } return self.state.compute_event_context(event) -- cgit 1.4.1 From 721414d98af55f6527b1dd1fa77ebe25c05ecfe2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 17:49:05 +0100 Subject: Add desc --- synapse/storage/roommember.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'synapse/storage') diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 3ffad672a7..2cab065bca 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -395,7 +395,8 @@ class RoomMemberStore(SQLBaseStore): retcols=['user_id'], keyvalues={ "membership": Membership.JOIN, - } + }, + desc="_get_joined_users_from_context", ) defer.returnValue(set(row["user_id"] for row in rows)) -- cgit 1.4.1 From 778fa85f4714c528e73a50af2c1c5fa4f30573eb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 25 Aug 2016 18:59:44 +0100 Subject: Make sync not pull out full state --- synapse/handlers/sync.py | 75 ++++++++++++++++++++++++++---------------------- synapse/storage/state.py | 33 +++++++++++++++++++++ 2 files changed, 74 insertions(+), 34 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8dfd02e7b..5cd009a1c8 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -355,11 +355,11 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state = yield self.store.get_state_for_event(event.event_id) + state_ids = yield self.store.get_state_ids_for_event(event.event_id) if event.is_state(): - state = state.copy() - state[(event.type, event.state_key)] = event - defer.returnValue(state) + state_ids = state_ids.copy() + state_ids[(event.type, event.state_key)] = event.event_id + defer.returnValue(state_ids) @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): @@ -412,57 +412,61 @@ class SyncHandler(object): with Measure(self.clock, "compute_state_delta"): if full_state: if batch: - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state = yield self.store.get_state_for_event( + state_ids = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) else: - current_state = yield self.get_state_at( + current_state_ids = yield self.get_state_at( room_id, stream_position=now_token ) - state = current_state + state_ids = current_state_ids timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, - timeline_start=state, + timeline_start=state_ids, previous={}, - current=current_state, + current=current_state_ids, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token ) - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state_at_timeline_start = yield self.store.get_state_for_event( + state_at_timeline_start = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, - current=current_state, + current=current_state_ids, ) else: - state = {} + state_ids = {} + + state = {} + if state_ids: + state = yield self.store.get_events(state_ids.values()) defer.returnValue({ (e.type, e.state_key): e @@ -766,8 +770,13 @@ class SyncHandler(object): # the last sync (even if we have since left). This is to make sure # we do send down the room, and with full state, where necessary if room_id in joined_room_ids or has_join: - old_state = yield self.get_state_at(room_id, since_token) - old_mem_ev = old_state.get((EventTypes.Member, user_id), None) + old_state_ids = yield self.get_state_at(room_id, since_token) + old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) + old_mem_ev = None + if old_mem_ev_id: + old_mem_ev = yield self.store.get_event( + old_mem_ev_id, allow_none=True + ) if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) @@ -1059,27 +1068,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): Returns: dict """ - event_id_to_state = { - e.event_id: e - for e in itertools.chain( - timeline_contains.values(), - previous.values(), - timeline_start.values(), - current.values(), + event_id_to_key = { + e: key + for key, e in itertools.chain( + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(e.event_id for e in current.values()) - tc_ids = set(e.event_id for e in timeline_contains.values()) - p_ids = set(e.event_id for e in previous.values()) - ts_ids = set(e.event_id for e in timeline_start.values()) + c_ids = set(e for e in current.values()) + tc_ids = set(e for e in timeline_contains.values()) + p_ids = set(e for e in previous.values()) + ts_ids = set(e for e in timeline_start.values()) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - evs = (event_id_to_state[e] for e in state_ids) return { - (e.type, e.state_key): e - for e in evs + event_id_to_key[e]: e for e in state_ids } diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 22f7fb1aa1..b1d461fef5 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -283,6 +283,22 @@ class StateStore(SQLBaseStore): defer.returnValue({event: event_to_state[event] for event in event_ids}) + @defer.inlineCallbacks + def get_state_ids_for_events(self, event_ids, types): + event_to_groups = yield self._get_state_group_for_events( + event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups, types) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + defer.returnValue({event: event_to_state[event] for event in event_ids}) + @defer.inlineCallbacks def get_state_for_event(self, event_id, types=None): """ @@ -300,6 +316,23 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_for_events([event_id], types) defer.returnValue(state_map[event_id]) + @defer.inlineCallbacks + def get_state_ids_for_event(self, event_id, types=None): + """ + Get the state dict corresponding to a particular event + + Args: + event_id(str): event whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from (type, state_key) -> state_event + """ + state_map = yield self.get_state_ids_for_events([event_id], types) + defer.returnValue(state_map[event_id]) + @cached(num_args=2, max_entries=10000) def _get_state_group_for_event(self, room_id, event_id): return self._simple_select_one_onecol( -- cgit 1.4.1 From 1ccdc1e93a5ae854fa89751a78c9103940a9f9e6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Aug 2016 10:59:40 +0100 Subject: Cache check_host_in_room --- synapse/api/auth.py | 20 ++++++-------------- synapse/storage/roommember.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 14 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 597631a88f..f26e585623 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -279,22 +279,14 @@ class Auth(object): @defer.inlineCallbacks def check_host_in_room(self, room_id, host): with Measure(self.clock, "check_host_in_room"): - curr_state_id = yield self.state.get_current_state_ids(room_id) + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - for (etype, state_key), event_id in curr_state_id.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - event = yield self.store.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) + group, curr_state_ids = yield self.state.resolve_state_groups( + room_id, latest_event_ids + ) - defer.returnValue(False) + ret = yield self.store.is_host_joined(room_id, host, group, curr_state_ids) + defer.returnValue(ret) def check_event_sender_in_room(self, event, auth_events): key = (EventTypes.Member, event.user_id, ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 2cab065bca..5ce5e8da37 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -400,3 +400,38 @@ class RoomMemberStore(SQLBaseStore): ) defer.returnValue(set(row["user_id"] for row in rows)) + + def is_host_joined(self, room_id, host, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _is_host_joined(self, room_id, host, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + if get_domain_from_id(state_key) != host: + continue + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) -- cgit 1.4.1 From 4daa397a00b1b7080686fde34a3858342e4b0498 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 26 Aug 2016 13:02:08 +0100 Subject: Add is_host_joined to slave storage --- synapse/replication/slave/storage/events.py | 2 ++ synapse/storage/roommember.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) (limited to 'synapse/storage') diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 3a71e7b292..65e982a0ce 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -127,6 +127,8 @@ class SlavedEventStore(BaseSlavedStore): get_room_events_stream_for_rooms = ( DataStore.get_room_events_stream_for_rooms.__func__ ) + is_host_joined = DataStore.is_host_joined.__func__ + _is_host_joined = RoomMemberStore.__dict__["_is_host_joined"] get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ _set_before_and_after = staticmethod(DataStore._set_before_and_after) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5ce5e8da37..5f15200c20 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -396,6 +396,7 @@ class RoomMemberStore(SQLBaseStore): keyvalues={ "membership": Membership.JOIN, }, + batch_size=1000, desc="_get_joined_users_from_context", ) @@ -409,8 +410,8 @@ class RoomMemberStore(SQLBaseStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_users_from_context( - room_id, state_group, state_ids + return self._is_host_joined( + room_id, host, state_group, state_ids ) @cachedInlineCallbacks(num_args=3) @@ -430,7 +431,7 @@ class RoomMemberStore(SQLBaseStore): logger.warn("state_key not user_id: %s", state_key) continue - event = yield self.store.get_event(event_id, allow_none=True) + event = yield self.get_event(event_id, allow_none=True) if event and event.content["membership"] == Membership.JOIN: defer.returnValue(True) -- cgit 1.4.1