diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/_base.py | 30 | ||||
-rw-r--r-- | synapse/handlers/appservice.py | 10 | ||||
-rw-r--r-- | synapse/handlers/directory.py | 8 | ||||
-rw-r--r-- | synapse/handlers/events.py | 3 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 227 | ||||
-rw-r--r-- | synapse/handlers/message.py | 93 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 48 | ||||
-rw-r--r-- | synapse/handlers/receipts.py | 5 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 130 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 126 | ||||
-rw-r--r-- | synapse/handlers/typing.py | 9 |
11 files changed, 498 insertions, 191 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 11081a0cd5..e58735294e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -65,33 +65,21 @@ class BaseHandler(object): retry_after_ms=int(1000 * (time_allowed - time_now)), ) - def is_host_in_room(self, current_state): - room_members = [ - (state_key, event.membership) - for ((event_type, state_key), event) in current_state.items() - if event_type == EventTypes.Member - ] - if len(room_members) == 0: - # Have we just created the room, and is this about to be the very - # first member event? - create_event = current_state.get(("m.room.create", "")) - if create_event: - return True - for (state_key, membership) in room_members: - if ( - self.hs.is_mine_id(state_key) - and membership == Membership.JOIN - ): - return True - return False - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, current_state): + def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": + if context: + current_state = yield self.store.get_events( + context.current_state_ids.values() + ) + current_state = current_state.values() + else: + current_state = yield self.store.get_current_state(event.room_id) + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) @defer.inlineCallbacks diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 306686a384..b440280b74 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -176,6 +176,16 @@ class ApplicationServicesHandler(object): defer.returnValue(ret) @defer.inlineCallbacks + def get_3pe_protocols(self): + services = yield self.store.get_app_services() + protocols = {} + for s in services: + for p in s.protocols: + protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p) + + defer.returnValue(protocols) + + @defer.inlineCallbacks def _get_services_for_event(self, event): """Retrieve a list of application services interested in this event. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4bea7f2b19..14352985e2 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,7 +19,7 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes -from synapse.types import RoomAlias, UserID +from synapse.types import RoomAlias, UserID, get_domain_from_id import logging import string @@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) if not servers: raise SynapseError(400, "Failed to get server list") @@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND ) - extra_servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + extra_servers = set(get_domain_from_id(u) for u in users) servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 3a3a1257d3..d3685fb12a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.store.get_users_in_room(event.room_id) + users = yield self.state.get_current_user_in_room(event.room_id) states = yield presence_handler.get_states( users, as_event=True, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 01a761715b..dc90a5dde4 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -29,6 +29,7 @@ from synapse.util import unwrapFirstError from synapse.util.logcontext import ( PreserveLoggingContext, preserve_fn, preserve_context_over_deferred ) +from synapse.util.metrics import measure_func from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze @@ -100,6 +101,9 @@ class FederationHandler(BaseHandler): def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. + + auth_chain and state are None if we already have the necessary state + and prev_events in the db """ event = pdu @@ -117,12 +121,21 @@ class FederationHandler(BaseHandler): # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work - is_in_room = yield self.auth.check_host_in_room( - event.room_id, - self.server_name - ) - if not is_in_room and not event.internal_metadata.is_outlier(): - logger.debug("Got event for room we're not in.") + # If state and auth_chain are None, then we don't need to do this check + # as we already know we have enough state in the DB to handle this + # event. + if state and auth_chain and not event.internal_metadata.is_outlier(): + is_in_room = yield self.auth.check_host_in_room( + event.room_id, + self.server_name + ) + else: + is_in_room = True + if not is_in_room: + logger.info( + "Got event for room we're not in: %r %r", + event.room_id, event.event_id + ) try: event_stream_id, max_stream_id = yield self._persist_auth_tree( @@ -217,17 +230,28 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - prev_state = context.current_state.get((event.type, event.state_key)) - if not prev_state or prev_state.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally - # joined the room. Don't bother if the user is just - # changing their profile info. + # Only fire user_joined_room if the user has acutally + # joined the room. Don't bother if the user is just + # changing their profile info. + newly_joined = True + prev_state_id = context.prev_state_ids.get( + (event.type, event.state_key) + ) + if prev_state_id: + prev_state = yield self.store.get_event( + prev_state_id, allow_none=True, + ) + if prev_state and prev_state.membership == Membership.JOIN: + newly_joined = False + + if newly_joined: user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) + @measure_func("_filter_events_for_server") @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): - event_to_state = yield self.store.get_state_for_events( + event_to_state_ids = yield self.store.get_state_ids_for_events( frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), @@ -235,6 +259,30 @@ class FederationHandler(BaseHandler): ) ) + # We only want to pull out member events that correspond to the + # server's domain. + + def check_match(id): + try: + return server_name == get_domain_from_id(id) + except: + return False + + event_map = yield self.store.get_events([ + e_id for key_to_eid in event_to_state_ids.values() + for key, e_id in key_to_eid + if key[0] != EventTypes.Member or check_match(key[1]) + ]) + + event_to_state = { + e_id: { + key: event_map[inner_e_id] + for key, inner_e_id in key_to_eid.items() + if inner_e_id in event_map + } + for e_id, key_to_eid in event_to_state_ids.items() + } + def redact_disallowed(event, state): if not state: return event @@ -377,7 +425,9 @@ class FederationHandler(BaseHandler): )).addErrback(unwrapFirstError) auth_events.update({a.event_id: a for a in results if a}) required_auth.update( - a_id for event in results for a_id, _ in event.auth_events if event + a_id + for event in results if event + for a_id, _ in event.auth_events ) missing_auth = required_auth - set(auth_events) @@ -560,6 +610,18 @@ class FederationHandler(BaseHandler): ])) states = dict(zip(event_ids, [s[1] for s in states])) + state_map = yield self.store.get_events( + [e_id for ids in states.values() for e_id in ids], + get_prev_content=False + ) + states = { + key: { + k: state_map[e_id] + for k, e_id in state_dict.items() + if e_id in state_map + } for key, state_dict in states.items() + } + for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) @@ -722,7 +784,7 @@ class FederationHandler(BaseHandler): # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_join_request` - self.auth.check(event, auth_events=context.current_state, do_sig_check=False) + yield self.auth.check_from_context(event, context, do_sig_check=False) defer.returnValue(event) @@ -770,18 +832,11 @@ class FederationHandler(BaseHandler): new_pdu = event - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(s.state_key)) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - + message_handler = self.hs.get_handlers().message_handler + destinations = yield message_handler.get_joined_hosts_for_room_from_state( + context + ) + destinations = set(destinations) destinations.discard(origin) logger.debug( @@ -792,13 +847,15 @@ class FederationHandler(BaseHandler): self.replication_layer.send_pdu(new_pdu, destinations) - state_ids = [e.event_id for e in context.current_state.values()] + state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) + state = yield self.store.get_events(context.prev_state_ids.values()) + defer.returnValue({ - "state": context.current_state.values(), + "state": state.values(), "auth_chain": auth_chain, }) @@ -954,7 +1011,7 @@ class FederationHandler(BaseHandler): try: # The remote hasn't signed it yet, obviously. We'll do the full checks # when we get the event back in `on_send_leave_request` - self.auth.check(event, auth_events=context.current_state, do_sig_check=False) + yield self.auth.check_from_context(event, context, do_sig_check=False) except AuthError as e: logger.warn("Failed to create new leave %r because %s", event, e) raise e @@ -998,18 +1055,11 @@ class FederationHandler(BaseHandler): new_pdu = event - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.LEAVE: - destinations.add(get_domain_from_id(s.state_key)) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - + message_handler = self.hs.get_handlers().message_handler + destinations = yield message_handler.get_joined_hosts_for_room_from_state( + context + ) + destinations = set(destinations) destinations.discard(origin) logger.debug( @@ -1024,6 +1074,8 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def get_state_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ yield run_on_reactor() state_groups = yield self.store.get_state_groups( @@ -1065,6 +1117,34 @@ class FederationHandler(BaseHandler): defer.returnValue([]) @defer.inlineCallbacks + def get_state_ids_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ + yield run_on_reactor() + + state_groups = yield self.store.get_state_groups_ids( + room_id, [event_id] + ) + + if state_groups: + _, state = state_groups.items().pop() + results = state + + event = yield self.store.get_event(event_id) + if event and event.is_state(): + # Get previous state + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + if prev_id != event.event_id: + results[(event.type, event.state_key)] = prev_id + else: + del results[(event.type, event.state_key)] + + defer.returnValue(results.values()) + else: + defer.returnValue([]) + + @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, pdu_list, limit): in_room = yield self.auth.check_host_in_room(room_id, origin) @@ -1294,7 +1374,13 @@ class FederationHandler(BaseHandler): ) if not auth_events: - auth_events = context.current_state + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -1320,8 +1406,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR if event.type == EventTypes.GuestAccess: - full_context = yield self.store.get_current_state(room_id=event.room_id) - yield self.maybe_kick_guest_users(event, full_context) + yield self.maybe_kick_guest_users(event) defer.returnValue(context) @@ -1389,6 +1474,11 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None + if event_auth_events - current_state: have_events = yield self.store.have_events( event_auth_events - current_state @@ -1492,8 +1582,14 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1514,8 +1610,8 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. - auth_ids = self.auth.compute_auth_events( - event, context.current_state + auth_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1571,8 +1667,14 @@ class FederationHandler(BaseHandler): # 4. Look at rejects and their proofs. # TODO. - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() try: self.auth.check(event, auth_events=auth_events) @@ -1758,12 +1860,12 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, context.current_state) + yield self.auth.check_from_context(event, context) except AuthError as e: logger.warn("Denying new third party invite %r because %s", event, e) raise e - yield self._check_signature(event, auth_events=context.current_state) + yield self._check_signature(event, context) member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event(None, event, context) else: @@ -1789,11 +1891,11 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(event, auth_events=context.current_state) + self.auth.check_from_context(event, context) except AuthError as e: logger.warn("Denying third party invite %r because %s", event, e) raise e - yield self._check_signature(event, auth_events=context.current_state) + yield self._check_signature(event, context) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. @@ -1807,7 +1909,12 @@ class FederationHandler(BaseHandler): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - original_invite = context.current_state.get(key) + original_invite = None + original_invite_id = context.prev_state_ids.get(key) + if original_invite_id: + original_invite = yield self.store.get_event( + original_invite_id, allow_none=True + ) if not original_invite: logger.info( "Could not find invite event for third_party_invite - " @@ -1824,13 +1931,13 @@ class FederationHandler(BaseHandler): defer.returnValue((event, context)) @defer.inlineCallbacks - def _check_signature(self, event, auth_events): + def _check_signature(self, event, context): """ Checks that the signature in the event is consistent with its invite. Args: event (Event): The m.room.member event to check - auth_events (dict<(event type, state_key), event>): + context (EventContext): Raises: AuthError: if signature didn't match any keys, or key has been @@ -1841,10 +1948,14 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event = auth_events.get( + invite_event_id = context.prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) + invite_event = None + if invite_event_id: + invite_event = yield self.store.get_event(invite_event_id, allow_none=True) + if not invite_event: raise AuthError(403, "Could not find invite") diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 4c3cd9d12e..3577db0595 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.metrics import measure_func +from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -248,7 +249,7 @@ class MessageHandler(BaseHandler): assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = self.deduplicate_state_event(event, context) + prev_state = yield self.deduplicate_state_event(event, context) if prev_state is not None: defer.returnValue(prev_state) @@ -263,6 +264,7 @@ class MessageHandler(BaseHandler): presence = self.hs.get_presence_handler() yield presence.bump_presence_active_time(user) + @defer.inlineCallbacks def deduplicate_state_event(self, event, context): """ Checks whether event is in the latest resolved state in context. @@ -270,13 +272,17 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event = context.current_state.get((event.type, event.state_key)) + prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + if not prev_event: + return + if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: - return prev_event - return None + defer.returnValue(prev_event) + return @defer.inlineCallbacks def create_and_send_nonmember_event( @@ -802,8 +808,8 @@ class MessageHandler(BaseHandler): event = builder.build() logger.debug( - "Created event %s with current state: %s", - event.event_id, context.current_state, + "Created event %s with state: %s", + event.event_id, context.prev_state_ids, ) defer.returnValue( @@ -826,12 +832,12 @@ class MessageHandler(BaseHandler): self.ratelimit(requester) try: - self.auth.check(event, auth_events=context.current_state) + yield self.auth.check_from_context(event, context) except AuthError as err: logger.warn("Denying new event %r because %s", event, err) raise err - yield self.maybe_kick_guest_users(event, context.current_state.values()) + yield self.maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: # Check the alias is acually valid (at this time at least) @@ -859,6 +865,15 @@ class MessageHandler(BaseHandler): e.sender == event.sender ) + state_to_include_ids = [ + e_id + for k, e_id in context.current_state_ids.items() + if k[0] in self.hs.config.room_invite_state_types + or k[0] == EventTypes.Member and k[1] == event.sender + ] + + state_to_include = yield self.store.get_events(state_to_include_ids) + event.unsigned["invite_room_state"] = [ { "type": e.type, @@ -866,9 +881,7 @@ class MessageHandler(BaseHandler): "content": e.content, "sender": e.sender, } - for k, e in context.current_state.items() - if e.type in self.hs.config.room_invite_state_types - or is_inviter_member_event(e) + for e in state_to_include.values() ] invitee = UserID.from_string(event.state_key) @@ -890,7 +903,14 @@ class MessageHandler(BaseHandler): ) if event.type == EventTypes.Redaction: - if self.auth.check_redaction(event, auth_events=context.current_state): + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + if self.auth.check_redaction(event, auth_events=auth_events): original_event = yield self.store.get_event( event.redacts, check_redacted=False, @@ -904,7 +924,7 @@ class MessageHandler(BaseHandler): "You don't have permission to redact events" ) - if event.type == EventTypes.Create and context.current_state: + if event.type == EventTypes.Create and context.prev_state_ids: raise AuthError( 403, "Changing the room create event is forbidden", @@ -925,16 +945,7 @@ class MessageHandler(BaseHandler): event_stream_id, max_stream_id ) - destinations = set() - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add(get_domain_from_id(s.state_key)) - except SynapseError: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) + destinations = yield self.get_joined_hosts_for_room_from_state(context) @defer.inlineCallbacks def _notify(): @@ -952,3 +963,39 @@ class MessageHandler(BaseHandler): preserve_fn(federation_handler.handle_new_event)( event, destinations=destinations, ) + + def get_joined_hosts_for_room_from_state(self, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts_for_room_from_state( + state_group, context.current_state_ids + ) + + @cachedInlineCallbacks(num_args=1, cache_context=True) + def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids, + cache_context): + + # Don't bother getting state for people on the same HS + current_state = yield self.store.get_events([ + e_id for key, e_id in current_state_ids.items() + if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1]) + ]) + + destinations = set() + for e in current_state.itervalues(): + try: + if e.type == EventTypes.Member: + if e.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(e.state_key)) + except SynapseError: + logger.warn( + "Failed to get destination from event %s", e.event_id + ) + + defer.returnValue(destinations) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6a1fe76c88..cf82a2336e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -88,6 +88,8 @@ class PresenceHandler(object): self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() + self.state = hs.get_state_handler() + self.federation.register_edu_handler( "m.presence", self.incoming_presence ) @@ -189,6 +191,13 @@ class PresenceHandler(object): 5000, ) + self.clock.call_later( + 60, + self.clock.looping_call, + self._persist_unpersisted_changes, + 60 * 1000, + ) + metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) @defer.inlineCallbacks @@ -215,6 +224,27 @@ class PresenceHandler(object): logger.info("Finished _on_shutdown") @defer.inlineCallbacks + def _persist_unpersisted_changes(self): + """We periodically persist the unpersisted changes, as otherwise they + may stack up and slow down shutdown times. + """ + logger.info( + "Performing _persist_unpersisted_changes. Persiting %d unpersisted changes", + len(self.unpersisted_users_changes) + ) + + unpersisted = self.unpersisted_users_changes + self.unpersisted_users_changes = set() + + if unpersisted: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in unpersisted + ]) + + logger.info("Finished _persist_unpersisted_changes") + + @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state @@ -532,7 +562,9 @@ class PresenceHandler(object): if not local_states: continue - hosts = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + hosts = set(get_domain_from_id(u) for u in users) + for host in hosts: hosts_to_states.setdefault(host, []).extend(local_states) @@ -725,13 +757,13 @@ class PresenceHandler(object): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. + user_ids = yield self.state.get_current_user_in_room(room_id) if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) - hosts = yield self.store.get_joined_hosts_for_room(room_id) + hosts = set(get_domain_from_id(u) for u in user_ids) self._push_to_remotes({host: (state,) for host in hosts}) else: - user_ids = yield self.store.get_users_in_room(room_id) user_ids = filter(self.is_mine_id, user_ids) states = yield self.current_state_for_users(user_ids) @@ -918,7 +950,12 @@ def should_notify(old_state, new_state): if new_state.currently_active != old_state.currently_active: return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Only notify about last active bumps if we're not currently acive + if not (old_state.currently_active and new_state.currently_active): + return True + + elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. return True @@ -955,6 +992,7 @@ class PresenceEventSource(object): self.get_presence_handler = hs.get_presence_handler self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -1017,7 +1055,7 @@ class PresenceEventSource(object): user_ids_to_check = set() for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) user_ids_to_check.update(users) user_ids_to_check.update(friends) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e62722d78d..726f7308d2 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from ._base import BaseHandler from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import get_domain_from_id import logging @@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler): "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() + self.state = hs.get_state_handler() @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, @@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler): event_ids = receipt["event_ids"] data = receipt["data"] - remotedomains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) remotedomains = remotedomains.copy() remotedomains.discard(self.server_name) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 8b17632fdc..ba49075a20 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,12 @@ class RoomMemberHandler(BaseHandler): prev_event_ids=prev_event_ids, ) + # Check if this event matches the previous membership event for the user. + duplicate = yield msg_handler.deduplicate_state_event(event, context) + if duplicate is not None: + # Discard the new event since this membership change is a no-op. + return + yield msg_handler.handle_new_client_event( requester, event, @@ -93,20 +99,26 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event = context.current_state.get( + prev_member_event_id = context.prev_state_ids.get( (EventTypes.Member, target.to_string()), None ) if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: yield user_joined_room(self.distributor, target, room_id) elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target, room_id) + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target, room_id) @defer.inlineCallbacks def remote_join(self, remote_room_hosts, room_id, user, content): @@ -195,29 +207,32 @@ class RoomMemberHandler(BaseHandler): remote_room_hosts = [] latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) - current_state = yield self.state_handler.get_current_state( + current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, ) - old_state = current_state.get((EventTypes.Member, target.to_string())) - old_membership = old_state.content.get("membership") if old_state else None - if action == "unban" and old_membership != "ban": - raise SynapseError( - 403, - "Cannot unban user who was not banned (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE - ) - if old_membership == "ban" and action != "unban": - raise SynapseError( - 403, - "Cannot %s user who was banned" % (action,), - errcode=Codes.BAD_STATE - ) + old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + if old_state_id: + old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned" + " (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was banned" % (action,), + errcode=Codes.BAD_STATE + ) - is_host_in_room = self.is_host_in_room(current_state) + is_host_in_room = yield self._is_host_in_room(current_state_ids) if effective_membership_state == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(current_state): + if requester.is_guest and not self._can_guest_join(current_state_ids): # This should be an auth check, but guests are a local concept, # so don't really fit into the general auth process. raise AuthError(403, "Guest access not allowed") @@ -326,15 +341,17 @@ class RoomMemberHandler(BaseHandler): requester = synapse.types.create_requester(target_user) message_handler = self.hs.get_handlers().message_handler - prev_event = message_handler.deduplicate_state_event(event, context) + prev_event = yield message_handler.deduplicate_state_event(event, context) if prev_event is not None: return if event.membership == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(context.current_state): - # This should be an auth check, but guests are a local concept, - # so don't really fit into the general auth process. - raise AuthError(403, "Guest access not allowed") + if requester.is_guest: + guest_can_join = yield self._can_guest_join(context.prev_state_ids) + if not guest_can_join: + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") yield message_handler.handle_new_client_event( requester, @@ -344,27 +361,39 @@ class RoomMemberHandler(BaseHandler): ratelimit=ratelimit, ) - prev_member_event = context.current_state.get( - (EventTypes.Member, target_user.to_string()), + prev_member_event_id = context.prev_state_ids.get( + (EventTypes.Member, event.state_key), None ) if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: yield user_joined_room(self.distributor, target_user, room_id) elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) - def _can_guest_join(self, current_state): + @defer.inlineCallbacks + def _can_guest_join(self, current_state_ids): """ Returns whether a guest can join a room based on its current state. """ - guest_access = current_state.get((EventTypes.GuestAccess, ""), None) - return ( + guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + if not guest_access_id: + defer.returnValue(False) + + guest_access = yield self.store.get_event(guest_access_id) + + defer.returnValue( guest_access and guest_access.content and "guest_access" in guest_access.content @@ -683,3 +712,24 @@ class RoomMemberHandler(BaseHandler): if membership: yield self.store.forget(user_id, room_id) + + @defer.inlineCallbacks + def _is_host_in_room(self, current_state_ids): + # Have we just created the room, and is this about to be the very + # first member event? + create_event_id = current_state_ids.get(("m.room.create", "")) + if len(current_state_ids) == 1 and create_event_id: + defer.returnValue(self.hs.is_mine_id(create_event_id)) + + for (etype, state_key), event_id in current_state_ids.items(): + if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if not event: + continue + + if event.membership == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8dfd02e7b..b5962f4f5a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -35,6 +35,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [ "filter_collection", "is_guest", "request_key", + "device_id", ]) @@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. ])): __slots__ = [] @@ -126,7 +128,8 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.joined or self.invited or self.archived or - self.account_data + self.account_data or + self.to_device ) @@ -139,6 +142,7 @@ class SyncHandler(object): self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.response_cache = ResponseCache(hs) + self.state = hs.get_state_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -355,11 +359,11 @@ class SyncHandler(object): Returns: A Deferred map from ((type, state_key)->Event) """ - state = yield self.store.get_state_for_event(event.event_id) + state_ids = yield self.store.get_state_ids_for_event(event.event_id) if event.is_state(): - state = state.copy() - state[(event.type, event.state_key)] = event - defer.returnValue(state) + state_ids = state_ids.copy() + state_ids[(event.type, event.state_key)] = event.event_id + defer.returnValue(state_ids) @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): @@ -412,57 +416,61 @@ class SyncHandler(object): with Measure(self.clock, "compute_state_delta"): if full_state: if batch: - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state = yield self.store.get_state_for_event( + state_ids = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) else: - current_state = yield self.get_state_at( + current_state_ids = yield self.get_state_at( room_id, stream_position=now_token ) - state = current_state + state_ids = current_state_ids timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, - timeline_start=state, + timeline_start=state_ids, previous={}, - current=current_state, + current=current_state_ids, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token ) - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state_at_timeline_start = yield self.store.get_state_for_event( + state_at_timeline_start = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, - current=current_state, + current=current_state_ids, ) else: - state = {} + state_ids = {} + + state = {} + if state_ids: + state = yield self.store.get_events(state_ids.values()) defer.returnValue({ (e.type, e.state_key): e @@ -527,16 +535,58 @@ class SyncHandler(object): sync_result_builder, newly_joined_rooms, newly_joined_users ) + yield self._generate_sync_entry_for_to_device(sync_result_builder) + defer.returnValue(SyncResult( presence=sync_result_builder.presence, account_data=sync_result_builder.account_data, joined=sync_result_builder.joined, invited=sync_result_builder.invited, archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, next_batch=sync_result_builder.now_token, )) @defer.inlineCallbacks + def _generate_sync_entry_for_to_device(self, sync_result_builder): + """Generates the portion of the sync response. Populates + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ + user_id = sync_result_builder.sync_config.user.to_string() + device_id = sync_result_builder.sync_config.device_id + now_token = sync_result_builder.now_token + since_stream_id = 0 + if sync_result_builder.since_token is not None: + since_stream_id = int(sync_result_builder.since_token.to_device_key) + + if since_stream_id != int(now_token.to_device_key): + # We only delete messages when a new message comes in, but that's + # fine so long as we delete them at some point. + + logger.debug("Deleting messages up to %d", since_stream_id) + yield self.store.delete_messages_for_device( + user_id, device_id, since_stream_id + ) + + logger.debug("Getting messages up to %d", now_token.to_device_key) + messages, stream_id = yield self.store.get_new_messages_for_device( + user_id, device_id, since_stream_id, now_token.to_device_key + ) + logger.debug("Got messages up to %d: %r", stream_id, messages) + sync_result_builder.now_token = now_token.copy_and_replace( + "to_device_key", stream_id + ) + sync_result_builder.to_device = messages + else: + sync_result_builder.to_device = [] + + @defer.inlineCallbacks def _generate_sync_entry_for_account_data(self, sync_result_builder): """Generates the account data portion of the sync response. Populates `sync_result_builder` with the result. @@ -626,7 +676,7 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_users) for room_id in newly_joined_rooms: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) @@ -766,8 +816,13 @@ class SyncHandler(object): # the last sync (even if we have since left). This is to make sure # we do send down the room, and with full state, where necessary if room_id in joined_room_ids or has_join: - old_state = yield self.get_state_at(room_id, since_token) - old_mem_ev = old_state.get((EventTypes.Member, user_id), None) + old_state_ids = yield self.get_state_at(room_id, since_token) + old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) + old_mem_ev = None + if old_mem_ev_id: + old_mem_ev = yield self.store.get_event( + old_mem_ev_id, allow_none=True + ) if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) @@ -1059,27 +1114,25 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): Returns: dict """ - event_id_to_state = { - e.event_id: e - for e in itertools.chain( - timeline_contains.values(), - previous.values(), - timeline_start.values(), - current.values(), + event_id_to_key = { + e: key + for key, e in itertools.chain( + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(e.event_id for e in current.values()) - tc_ids = set(e.event_id for e in timeline_contains.values()) - p_ids = set(e.event_id for e in previous.values()) - ts_ids = set(e.event_id for e in timeline_start.values()) + c_ids = set(e for e in current.values()) + tc_ids = set(e for e in timeline_contains.values()) + p_ids = set(e for e in previous.values()) + ts_ids = set(e for e in timeline_start.values()) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - evs = (event_id_to_state[e] for e in state_ids) return { - (e.type, e.state_key): e - for e in evs + event_id_to_key[e]: e for e in state_ids } @@ -1103,6 +1156,7 @@ class SyncResultBuilder(object): self.joined = [] self.invited = [] self.archived = [] + self.device = [] class RoomSyncResultBuilder(object): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 46181984c0..0b530b9034 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -20,7 +20,7 @@ from synapse.util.logcontext import ( PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, ) from synapse.util.metrics import Measure -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id import logging @@ -42,6 +42,7 @@ class TypingHandler(object): self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() self.clock = hs.get_clock() @@ -166,7 +167,8 @@ class TypingHandler(object): @defer.inlineCallbacks def _push_update(self, room_id, user_id, typing): - domains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) deferreds = [] for domain in domains: @@ -199,7 +201,8 @@ class TypingHandler(object): # Check that the string is a valid user id UserID.from_string(user_id) - domains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) if self.server_name in domains: self._push_update_local( |