diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/handlers/__init__.py | 2 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 9 | ||||
-rw-r--r-- | synapse/handlers/directory.py | 14 | ||||
-rw-r--r-- | synapse/handlers/events.py | 79 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 583 | ||||
-rw-r--r-- | synapse/handlers/message.py | 100 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 40 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 3 | ||||
-rw-r--r-- | synapse/handlers/register.py | 20 | ||||
-rw-r--r-- | synapse/handlers/room.py | 52 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 439 | ||||
-rw-r--r-- | synapse/handlers/typing.py | 21 |
12 files changed, 1115 insertions, 247 deletions
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index fe071a4bc2..a32eab9316 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -26,6 +26,7 @@ from .presence import PresenceHandler from .directory import DirectoryHandler from .typing import TypingNotificationHandler from .admin import AdminHandler +from .sync import SyncHandler class Handlers(object): @@ -51,3 +52,4 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) + self.sync_handler = SyncHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 38af034b4d..1773fa20aa 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -19,6 +19,7 @@ from synapse.api.errors import LimitExceededError, SynapseError from synapse.util.async import run_on_reactor from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes +from synapse.types import UserID import logging @@ -113,7 +114,7 @@ class BaseHandler(object): if event.type == EventTypes.Member: if event.content["membership"] == Membership.INVITE: - invitee = self.hs.parse_userid(event.state_key) + invitee = UserID.from_string(event.state_key) if not self.hs.is_mine(invitee): # TODO: Can we add signature from remote server in a nicer # way? If we have been invited by a remote server, we need @@ -134,7 +135,7 @@ class BaseHandler(object): if k[0] == EventTypes.Member: if s.content["membership"] == Membership.JOIN: destinations.add( - self.hs.parse_userid(s.state_key).domain + UserID.from_string(s.state_key).domain ) except SynapseError: logger.warn( @@ -144,7 +145,5 @@ class BaseHandler(object): yield self.notifier.on_new_room_event(event, extra_users=extra_users) yield federation_handler.handle_new_event( - event, - None, - destinations=destinations, + event, destinations=destinations, ) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 91fceda2ac..7b60921040 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,6 +19,7 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.constants import EventTypes +from synapse.types import RoomAlias import logging @@ -112,7 +113,16 @@ class DirectoryHandler(BaseHandler): ) extra_servers = yield self.store.get_joined_hosts_for_room(room_id) - servers = list(set(extra_servers) | set(servers)) + servers = set(extra_servers) | set(servers) + + # If this server is in the list of servers, return it first. + if self.server_name in servers: + servers = ( + [self.server_name] + + [s for s in servers if s != self.server_name] + ) + else: + servers = list(servers) defer.returnValue({ "room_id": room_id, @@ -122,7 +132,7 @@ class DirectoryHandler(BaseHandler): @defer.inlineCallbacks def on_directory_query(self, args): - room_alias = self.hs.parse_roomalias(args["room_alias"]) + room_alias = RoomAlias.from_string(args["room_alias"]) if not self.hs.is_mine(room_alias): raise SynapseError( 400, "Room Alias is not hosted on this Home Server" diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 808219bd10..025e7e7e62 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -17,6 +17,8 @@ from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logutils import log_function +from synapse.types import UserID +from synapse.events.utils import serialize_event from ._base import BaseHandler @@ -46,39 +48,43 @@ class EventStreamHandler(BaseHandler): @defer.inlineCallbacks @log_function - def get_stream(self, auth_user_id, pagin_config, timeout=0): - auth_user = self.hs.parse_userid(auth_user_id) + def get_stream(self, auth_user_id, pagin_config, timeout=0, + as_client_event=True, affect_presence=True): + auth_user = UserID.from_string(auth_user_id) try: - if auth_user not in self._streams_per_user: - self._streams_per_user[auth_user] = 0 - if auth_user in self._stop_timer_per_user: - try: - self.clock.cancel_call_later( - self._stop_timer_per_user.pop(auth_user) + if affect_presence: + if auth_user not in self._streams_per_user: + self._streams_per_user[auth_user] = 0 + if auth_user in self._stop_timer_per_user: + try: + self.clock.cancel_call_later( + self._stop_timer_per_user.pop(auth_user) + ) + except: + logger.exception("Failed to cancel event timer") + else: + yield self.distributor.fire( + "started_user_eventstream", auth_user ) - except: - logger.exception("Failed to cancel event timer") - else: - yield self.distributor.fire( - "started_user_eventstream", auth_user - ) - self._streams_per_user[auth_user] += 1 + self._streams_per_user[auth_user] += 1 if pagin_config.from_token is None: pagin_config.from_token = None rm_handler = self.hs.get_handlers().room_member_handler - logger.debug("BETA") room_ids = yield rm_handler.get_rooms_for_user(auth_user) - logger.debug("ALPHA") with PreserveLoggingContext(): events, tokens = yield self.notifier.get_events_for( auth_user, room_ids, pagin_config, timeout ) - chunks = [self.hs.serialize_event(e) for e in events] + time_now = self.clock.time_msec() + + chunks = [ + serialize_event(e, time_now, as_client_event) for e in events + ] chunk = { "chunk": chunks, @@ -89,27 +95,28 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) finally: - self._streams_per_user[auth_user] -= 1 - if not self._streams_per_user[auth_user]: - del self._streams_per_user[auth_user] - - # 10 seconds of grace to allow the client to reconnect again - # before we think they're gone - def _later(): - logger.debug( - "_later stopped_user_eventstream %s", auth_user - ) + if affect_presence: + self._streams_per_user[auth_user] -= 1 + if not self._streams_per_user[auth_user]: + del self._streams_per_user[auth_user] + + # 10 seconds of grace to allow the client to reconnect again + # before we think they're gone + def _later(): + logger.debug( + "_later stopped_user_eventstream %s", auth_user + ) - self._stop_timer_per_user.pop(auth_user, None) + self._stop_timer_per_user.pop(auth_user, None) - yield self.distributor.fire( - "stopped_user_eventstream", auth_user - ) + return self.distributor.fire( + "stopped_user_eventstream", auth_user + ) - logger.debug("Scheduling _later: for %s", auth_user) - self._stop_timer_per_user[auth_user] = ( - self.clock.call_later(30, _later) - ) + logger.debug("Scheduling _later: for %s", auth_user) + self._stop_timer_per_user[auth_user] = ( + self.clock.call_later(30, _later) + ) class EventHandler(BaseHandler): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d26975a88a..77c81fe2da 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,21 +17,21 @@ from ._base import BaseHandler -from synapse.events.utils import prune_event from synapse.api.errors import ( - AuthError, FederationError, SynapseError, StoreError, + AuthError, FederationError, StoreError, ) -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor +from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( - compute_event_signature, check_event_content_hash, - add_hashes_and_signatures, + compute_event_signature, add_hashes_and_signatures, ) -from syutil.jsonutil import encode_canonical_json +from synapse.types import UserID from twisted.internet import defer +import itertools import logging @@ -75,14 +75,14 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def handle_new_event(self, event, snapshot, destinations): + def handle_new_event(self, event, destinations): """ Takes in an event from the client to server side, that has already been authed and handled by the state module, and sends it to any remote home servers that may be interested. Args: - event - snapshot (.storage.Snapshot): THe snapshot the event happened after + event: The event to send + destinations: A list of destinations to send it to Returns: Deferred: Resolved when it has successfully been queued for @@ -112,33 +112,6 @@ class FederationHandler(BaseHandler): logger.debug("Processing event: %s", event.event_id) - redacted_event = prune_event(event) - - redacted_pdu_json = redacted_event.get_pdu_json() - try: - yield self.keyring.verify_json_for_server( - event.origin, redacted_pdu_json - ) - except SynapseError as e: - logger.warn( - "Signature check failed for %s redacted to %s", - encode_canonical_json(pdu.get_pdu_json()), - encode_canonical_json(redacted_pdu_json), - ) - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) - - if not check_event_content_hash(event): - logger.warn( - "Event content has been tampered, redacting %s, %s", - event.event_id, encode_canonical_json(event.get_dict()) - ) - event = redacted_event - logger.debug("Event: %s", event) # FIXME (erikj): Awful hack to make the case where we are not currently @@ -148,41 +121,38 @@ class FederationHandler(BaseHandler): event.room_id, self.server_name ) - if not is_in_room and not event.internal_metadata.outlier: + if not is_in_room and not event.internal_metadata.is_outlier(): logger.debug("Got event for room we're not in.") + current_state = state - replication = self.replication_layer + event_ids = set() + if state: + event_ids |= {e.event_id for e in state} + if auth_chain: + event_ids |= {e.event_id for e in auth_chain} - if not state: - state, auth_chain = yield replication.get_state_for_context( - origin, context=event.room_id, event_id=event.event_id, - ) + seen_ids = set( + (yield self.store.have_events(event_ids)).keys() + ) - if not auth_chain: - auth_chain = yield replication.get_event_auth( - origin, - context=event.room_id, - event_id=event.event_id, - ) + if state and auth_chain is not None: + # If we have any state or auth_chain given to us by the replication + # layer, then we should handle them (if we haven't before.) + for e in itertools.chain(auth_chain, state): + if e.event_id in seen_ids: + continue - for e in auth_chain: e.internal_metadata.outlier = True try: - yield self._handle_new_event(e, fetch_auth_from=origin) - except: - logger.exception( - "Failed to handle auth event %s", - e.event_id, + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + origin, e, auth_events=auth ) - - current_state = state - - if state: - for e in state: - logging.info("A :) %r", e) - e.internal_metadata.outlier = True - try: - yield self._handle_new_event(e) + seen_ids.add(e.event_id) except: logger.exception( "Failed to handle state event %s", @@ -191,6 +161,7 @@ class FederationHandler(BaseHandler): try: yield self._handle_new_event( + origin, event, state=state, backfilled=backfilled, @@ -227,7 +198,7 @@ class FederationHandler(BaseHandler): extra_users = [] if event.type == EventTypes.Member: target_user_id = event.state_key - target_user = self.hs.parse_userid(target_user_id) + target_user = UserID.from_string(target_user_id) extra_users.append(target_user) yield self.notifier.on_new_room_event( @@ -236,7 +207,7 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - user = self.hs.parse_userid(event.state_key) + user = UserID.from_string(event.state_key) yield self.distributor.fire( "user_joined_room", user=user, room_id=event.room_id ) @@ -281,7 +252,7 @@ class FederationHandler(BaseHandler): """ pdu = yield self.replication_layer.send_invite( destination=target_host, - context=event.room_id, + room_id=event.room_id, event_id=event.event_id, pdu=event ) @@ -305,7 +276,7 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def do_invite_join(self, target_host, room_id, joinee, content, snapshot): + def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot): """ Attempts to join the `joinee` to the room `room_id` via the server `target_host`. @@ -319,8 +290,8 @@ class FederationHandler(BaseHandler): """ logger.debug("Joining %s to %s", joinee, room_id) - pdu = yield self.replication_layer.make_join( - target_host, + origin, pdu = yield self.replication_layer.make_join( + target_hosts, room_id, joinee ) @@ -341,7 +312,7 @@ class FederationHandler(BaseHandler): self.room_queues[room_id] = [] builder = self.event_builder_factory.new( - event.get_pdu_json() + unfreeze(event.get_pdu_json()) ) handled_events = set() @@ -362,11 +333,20 @@ class FederationHandler(BaseHandler): new_event = builder.build() + # Try the host we successfully got a response to /make_join/ + # request first. + try: + target_hosts.remove(origin) + target_hosts.insert(0, origin) + except ValueError: + pass + ret = yield self.replication_layer.send_join( - target_host, + target_hosts, new_event ) + origin = ret["origin"] state = ret["state"] auth_chain = ret["auth_chain"] auth_chain.sort(key=lambda e: e.depth) @@ -392,8 +372,19 @@ class FederationHandler(BaseHandler): for e in auth_chain: e.internal_metadata.outlier = True + + if e.event_id == event.event_id: + continue + try: - yield self._handle_new_event(e) + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + origin, e, auth_events=auth + ) except: logger.exception( "Failed to handle auth event %s", @@ -401,11 +392,18 @@ class FederationHandler(BaseHandler): ) for e in state: - # FIXME: Auth these. + if e.event_id == event.event_id: + continue + e.internal_metadata.outlier = True try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } yield self._handle_new_event( - e, fetch_auth_from=target_host + origin, e, auth_events=auth ) except: logger.exception( @@ -413,10 +411,18 @@ class FederationHandler(BaseHandler): e.event_id, ) + auth_ids = [e_id for e_id, _ in event.auth_events] + auth_events = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + origin, new_event, state=state, current_state=state, + auth_events=auth_events, ) yield self.notifier.on_new_room_event( @@ -480,7 +486,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event(event) + context = yield self._handle_new_event(origin, event) logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", @@ -491,7 +497,7 @@ class FederationHandler(BaseHandler): extra_users = [] if event.type == EventTypes.Member: target_user_id = event.state_key - target_user = self.hs.parse_userid(target_user_id) + target_user = UserID.from_string(target_user_id) extra_users.append(target_user) yield self.notifier.on_new_room_event( @@ -500,7 +506,7 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.Member: if event.content["membership"] == Membership.JOIN: - user = self.hs.parse_userid(event.state_key) + user = UserID.from_string(event.state_key) yield self.distributor.fire( "user_joined_room", user=user, room_id=event.room_id ) @@ -514,13 +520,15 @@ class FederationHandler(BaseHandler): if k[0] == EventTypes.Member: if s.content["membership"] == Membership.JOIN: destinations.add( - self.hs.parse_userid(s.state_key).domain + UserID.from_string(s.state_key).domain ) except: logger.warn( "Failed to get destination from event %s", s.event_id ) + destinations.discard(origin) + logger.debug( "on_send_join_request: Sending event: %s, signatures: %s", event.event_id, @@ -565,7 +573,7 @@ class FederationHandler(BaseHandler): backfilled=False, ) - target_user = self.hs.parse_userid(event.state_key) + target_user = UserID.from_string(event.state_key) yield self.notifier.on_new_room_event( event, extra_users=[target_user], ) @@ -617,13 +625,13 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function - def on_backfill_request(self, origin, context, pdu_list, limit): - in_room = yield self.auth.check_host_in_room(context, origin) + def on_backfill_request(self, origin, room_id, pdu_list, limit): + in_room = yield self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") events = yield self.store.get_backfill_events( - context, + room_id, pdu_list, limit ) @@ -641,6 +649,7 @@ class FederationHandler(BaseHandler): event = yield self.store.get_event( event_id, allow_none=True, + allow_rejected=True, ) if event: @@ -681,11 +690,12 @@ class FederationHandler(BaseHandler): waiters.pop().callback(None) @defer.inlineCallbacks - def _handle_new_event(self, event, state=None, backfilled=False, - current_state=None, fetch_auth_from=None): + @log_function + def _handle_new_event(self, origin, event, state=None, backfilled=False, + current_state=None, auth_events=None): logger.debug( - "_handle_new_event: Before annotate: %s, sigs: %s", + "_handle_new_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -693,65 +703,46 @@ class FederationHandler(BaseHandler): event, old_state=state ) + if not auth_events: + auth_events = context.auth_events + logger.debug( - "_handle_new_event: Before auth fetch: %s, sigs: %s", - event.event_id, event.signatures, + "_handle_new_event: %s, auth_events: %s", + event.event_id, auth_events, ) is_new_state = not event.internal_metadata.is_outlier() - known_ids = set( - [s.event_id for s in context.auth_events.values()] - ) - - for e_id, _ in event.auth_events: - if e_id not in known_ids: - e = yield self.store.get_event(e_id, allow_none=True) - - if not e and fetch_auth_from is not None: - # Grab the auth_chain over federation if we are missing - # auth events. - auth_chain = yield self.replication_layer.get_event_auth( - fetch_auth_from, event.event_id, event.room_id - ) - for auth_event in auth_chain: - yield self._handle_new_event(auth_event) - e = yield self.store.get_event(e_id, allow_none=True) - - if not e: - # TODO: Do some conflict res to make sure that we're - # not the ones who are wrong. - logger.info( - "Rejecting %s as %s not in db or %s", - event.event_id, e_id, known_ids, - ) - # FIXME: How does raising AuthError work with federation? - raise AuthError(403, "Cannot find auth event") - - context.auth_events[(e.type, e.state_key)] = e - - logger.debug( - "_handle_new_event: Before hack: %s, sigs: %s", - event.event_id, event.signatures, - ) - + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_events: if len(event.prev_events) == 1: c = yield self.store.get_event(event.prev_events[0][0]) if c.type == EventTypes.Create: - context.auth_events[(c.type, c.state_key)] = c + auth_events[(c.type, c.state_key)] = c - logger.debug( - "_handle_new_event: Before auth check: %s, sigs: %s", - event.event_id, event.signatures, - ) + try: + yield self.do_auth( + origin, event, context, auth_events=auth_events + ) + except AuthError as e: + logger.warn( + "Rejecting %s because %s", + event.event_id, e.msg + ) - self.auth.check(event, auth_events=context.auth_events) + context.rejected = RejectedReason.AUTH_ERROR - logger.debug( - "_handle_new_event: Before persist_event: %s, sigs: %s", - event.event_id, event.signatures, - ) + # FIXME: Don't store as rejected with AUTH_ERROR if we haven't + # seen all the auth events. + yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, + is_new_state=False, + current_state=current_state, + ) + raise yield self.store.persist_event( event, @@ -761,9 +752,329 @@ class FederationHandler(BaseHandler): current_state=current_state, ) - logger.debug( - "_handle_new_event: After persist_event: %s, sigs: %s", - event.event_id, event.signatures, + defer.returnValue(context) + + @defer.inlineCallbacks + def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, + missing): + # Just go through and process each event in `remote_auth_chain`. We + # don't want to fall into the trap of `missing` being wrong. + for e in remote_auth_chain: + try: + yield self._handle_new_event(origin, e) + except AuthError: + pass + + # Now get the current auth_chain for the event. + local_auth_chain = yield self.store.get_auth_chain([event_id]) + + # TODO: Check if we would now reject event_id. If so we need to tell + # everyone. + + ret = yield self.construct_auth_difference( + local_auth_chain, remote_auth_chain ) - defer.returnValue(context) + for event in ret["auth_chain"]: + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) + ) + + logger.debug("on_query_auth returning: %s", ret) + + defer.returnValue(ret) + + @defer.inlineCallbacks + @log_function + def do_auth(self, origin, event, context, auth_events): + # Check if we have all the auth events. + have_events = yield self.store.have_events( + [e_id for e_id, _ in event.auth_events] + ) + + event_auth_events = set(e_id for e_id, _ in event.auth_events) + seen_events = set(have_events.keys()) + + missing_auth = event_auth_events - seen_events + + if missing_auth: + logger.debug("Missing auth: %s", missing_auth) + # If we don't have all the auth events, we need to get them. + try: + remote_auth_chain = yield self.replication_layer.get_event_auth( + origin, event.room_id, event.event_id + ) + + seen_remotes = yield self.store.have_events( + [e.event_id for e in remote_auth_chain] + ) + + for e in remote_auth_chain: + if e.event_id in seen_remotes.keys(): + continue + + if e.event_id == event.event_id: + continue + + try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in remote_auth_chain + if e.event_id in auth_ids + } + e.internal_metadata.outlier = True + + logger.debug( + "do_auth %s missing_auth: %s", + event.event_id, e.event_id + ) + yield self._handle_new_event( + origin, e, auth_events=auth + ) + + if e.event_id in event_auth_events: + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + have_events = yield self.store.have_events( + [e_id for e_id, _ in event.auth_events] + ) + seen_events = set(have_events.keys()) + except: + # FIXME: + logger.exception("Failed to get auth chain") + + # FIXME: Assumes we have and stored all the state for all the + # prev_events + current_state = set(e.event_id for e in auth_events.values()) + different_auth = event_auth_events - current_state + + if different_auth and not event.internal_metadata.is_outlier(): + # Do auth conflict res. + logger.debug("Different auth: %s", different_auth) + + # Only do auth resolution if we have something new to say. + # We can't rove an auth failure. + do_resolution = False + + provable = [ + RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR, + ] + + for e_id in different_auth: + if e_id in have_events: + if have_events[e_id] in provable: + do_resolution = True + break + + if do_resolution: + # 1. Get what we think is the auth chain. + auth_ids = self.auth.compute_auth_events( + event, context.current_state + ) + local_auth_chain = yield self.store.get_auth_chain(auth_ids) + + try: + # 2. Get remote difference. + result = yield self.replication_layer.query_auth( + origin, + event.room_id, + event.event_id, + local_auth_chain, + ) + + seen_remotes = yield self.store.have_events( + [e.event_id for e in result["auth_chain"]] + ) + + # 3. Process any remote auth chain events we haven't seen. + for ev in result["auth_chain"]: + if ev.event_id in seen_remotes.keys(): + continue + + if ev.event_id == event.event_id: + continue + + try: + auth_ids = [e_id for e_id, _ in ev.auth_events] + auth = { + (e.type, e.state_key): e + for e in result["auth_chain"] + if e.event_id in auth_ids + } + ev.internal_metadata.outlier = True + + logger.debug( + "do_auth %s different_auth: %s", + event.event_id, e.event_id + ) + + yield self._handle_new_event( + origin, ev, auth_events=auth + ) + + if ev.event_id in event_auth_events: + auth_events[(ev.type, ev.state_key)] = ev + except AuthError: + pass + + except: + # FIXME: + logger.exception("Failed to query auth chain") + + # 4. Look at rejects and their proofs. + # TODO. + + context.current_state.update(auth_events) + context.state_group = None + + try: + self.auth.check(event, auth_events=auth_events) + except AuthError: + raise + + @defer.inlineCallbacks + def construct_auth_difference(self, local_auth, remote_auth): + """ Given a local and remote auth chain, find the differences. This + assumes that we have already processed all events in remote_auth + + Params: + local_auth (list) + remote_auth (list) + + Returns: + dict + """ + + logger.debug("construct_auth_difference Start!") + + # TODO: Make sure we are OK with local_auth or remote_auth having more + # auth events in them than strictly necessary. + + def sort_fun(ev): + return ev.depth, ev.event_id + + logger.debug("construct_auth_difference after sort_fun!") + + # We find the differences by starting at the "bottom" of each list + # and iterating up on both lists. The lists are ordered by depth and + # then event_id, we iterate up both lists until we find the event ids + # don't match. Then we look at depth/event_id to see which side is + # missing that event, and iterate only up that list. Repeat. + + remote_list = list(remote_auth) + remote_list.sort(key=sort_fun) + + local_list = list(local_auth) + local_list.sort(key=sort_fun) + + local_iter = iter(local_list) + remote_iter = iter(remote_list) + + logger.debug("construct_auth_difference before get_next!") + + def get_next(it, opt=None): + try: + return it.next() + except: + return opt + + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + + logger.debug("construct_auth_difference before while") + + missing_remotes = [] + missing_locals = [] + while current_local or current_remote: + if current_remote is None: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local is None: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + if current_local.event_id == current_remote.event_id: + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + continue + + if current_local.depth < current_remote.depth: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local.depth > current_remote.depth: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + # They have the same depth, so we fall back to the event_id order + if current_local.event_id < current_remote.event_id: + missing_locals.append(current_local) + current_local = get_next(local_iter) + + if current_local.event_id > current_remote.event_id: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + logger.debug("construct_auth_difference after while") + + # missing locals should be sent to the server + # We should find why we are missing remotes, as they will have been + # rejected. + + # Remove events from missing_remotes if they are referencing a missing + # remote. We only care about the "root" rejected ones. + missing_remote_ids = [e.event_id for e in missing_remotes] + base_remote_rejected = list(missing_remotes) + for e in missing_remotes: + for e_id, _ in e.auth_events: + if e_id in missing_remote_ids: + try: + base_remote_rejected.remove(e) + except ValueError: + pass + + reason_map = {} + + for e in base_remote_rejected: + reason = yield self.store.get_rejection_reason(e.event_id) + if reason is None: + # TODO: e is not in the current state, so we should + # construct some proof of that. + continue + + reason_map[e.event_id] = reason + + if reason == RejectedReason.AUTH_ERROR: + pass + elif reason == RejectedReason.REPLACED: + # TODO: Get proof + pass + elif reason == RejectedReason.NOT_ANCESTOR: + # TODO: Get proof. + pass + + logger.debug("construct_auth_difference returning") + + defer.returnValue({ + "auth_chain": local_auth, + "rejects": { + e.event_id: { + "reason": reason_map[e.event_id], + "proof": None, + } + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + }) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7195de98b5..3355adefcf 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,10 +16,12 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import RoomError +from synapse.api.errors import RoomError, SynapseError from synapse.streams.config import PaginationConfig +from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import UserID from ._base import BaseHandler @@ -33,6 +35,7 @@ class MessageHandler(BaseHandler): def __init__(self, hs): super(MessageHandler, self).__init__(hs) self.hs = hs + self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -67,7 +70,7 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def get_messages(self, user_id=None, room_id=None, pagin_config=None, - feedback=False): + feedback=False, as_client_event=True): """Get messages in a room. Args: @@ -76,6 +79,7 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. feedback (bool): True to get compressed feedback with the messages + as_client_event (bool): True to get events in client-server format. Returns: dict: Pagination API results """ @@ -88,7 +92,7 @@ class MessageHandler(BaseHandler): yield self.hs.get_event_sources().get_current_token() ) - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) events, next_key = yield data_source.get_pagination_rows( user, pagin_config.get_source_config("room"), room_id @@ -98,8 +102,12 @@ class MessageHandler(BaseHandler): "room_key", next_key ) + time_now = self.clock.time_msec() + chunk = { - "chunk": [self.hs.serialize_event(e) for e in events], + "chunk": [ + serialize_event(e, time_now, as_client_event) for e in events + ], "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), } @@ -107,7 +115,8 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True): + def create_and_send_event(self, event_dict, ratelimit=True, + client=None, txn_id=None): """ Given a dict from a client, create and handle a new event. Creates an FrozenEvent object, filling out auth_events, prev_events, @@ -127,13 +136,13 @@ class MessageHandler(BaseHandler): if ratelimit: self.ratelimit(builder.user_id) # TODO(paul): Why does 'event' not have a 'user' object? - user = self.hs.parse_userid(builder.user_id) + user = UserID.from_string(builder.user_id) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if builder.type == EventTypes.Member: membership = builder.content.get("membership", None) if membership == Membership.JOIN: - joinee = self.hs.parse_userid(builder.state_key) + joinee = UserID.from_string(builder.state_key) # If event doesn't include a display name, add one. yield self.distributor.fire( "collect_presencelike_data", @@ -141,6 +150,15 @@ class MessageHandler(BaseHandler): builder.content ) + if client is not None: + if client.token_id is not None: + builder.internal_metadata.token_id = client.token_id + if client.device_id is not None: + builder.internal_metadata.device_id = client.device_id + + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id + event, context = yield self._create_new_client_event( builder=builder, ) @@ -207,11 +225,14 @@ class MessageHandler(BaseHandler): # TODO: This is duplicating logic from snapshot_all_rooms current_state = yield self.state_handler.get_current_state(room_id) - defer.returnValue([self.hs.serialize_event(c) for c in current_state]) + now = self.clock.time_msec() + defer.returnValue( + [serialize_event(c, now) for c in current_state.values()] + ) @defer.inlineCallbacks def snapshot_all_rooms(self, user_id=None, pagin_config=None, - feedback=False): + feedback=False, as_client_event=True): """Retrieve a snapshot of all rooms the user is invited or has joined. This snapshot may include messages for all rooms where the user is @@ -222,6 +243,7 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config used to determine how many messages *PER ROOM* to return. feedback (bool): True to get feedback along with these messages. + as_client_event (bool): True to get events in client-server format. Returns: A list of dicts with "room_id" and "membership" keys for all rooms the user is currently invited or joined in on. Rooms where the user @@ -233,7 +255,7 @@ class MessageHandler(BaseHandler): membership_list=[Membership.INVITE, Membership.JOIN] ) - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) rooms_ret = [] @@ -278,9 +300,13 @@ class MessageHandler(BaseHandler): start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) + time_now = self.clock.time_msec() d["messages"] = { - "chunk": [self.hs.serialize_event(m) for m in messages], + "chunk": [ + serialize_event(m, time_now, as_client_event) + for m in messages + ], "start": start_token.to_string(), "end": end_token.to_string(), } @@ -289,7 +315,8 @@ class MessageHandler(BaseHandler): event.room_id ) d["state"] = [ - self.hs.serialize_event(c) for c in current_state + serialize_event(c, time_now, as_client_event) + for c in current_state.values() ] except: logger.exception("Failed to get snapshot") @@ -305,20 +332,27 @@ class MessageHandler(BaseHandler): @defer.inlineCallbacks def room_initial_sync(self, user_id, room_id, pagin_config=None, feedback=False): - yield self.auth.check_joined_room(room_id, user_id) + current_state = yield self.state.get_current_state( + room_id=room_id, + ) + + yield self.auth.check_joined_room( + room_id, user_id, + current_state=current_state + ) # TODO(paul): I wish I was called with user objects not user_id # strings... - auth_user = self.hs.parse_userid(user_id) + auth_user = UserID.from_string(user_id) # TODO: These concurrently - state_tuples = yield self.state_handler.get_current_state(room_id) - state = [self.hs.serialize_event(x) for x in state_tuples] + time_now = self.clock.time_msec() + state = [ + serialize_event(x, time_now) + for x in current_state.values() + ] - member_event = (yield self.store.get_room_member( - user_id=user_id, - room_id=room_id - )) + member_event = current_state.get((EventTypes.Member, user_id,)) now_token = yield self.hs.get_event_sources().get_current_token() @@ -335,28 +369,40 @@ class MessageHandler(BaseHandler): start_token = now_token.copy_and_replace("room_key", token[0]) end_token = now_token.copy_and_replace("room_key", token[1]) - room_members = yield self.store.get_room_members(room_id) + room_members = [ + m for m in current_state.values() + if m.type == EventTypes.Member + ] presence_handler = self.hs.get_handlers().presence_handler presence = [] for m in room_members: try: member_presence = yield presence_handler.get_state( - target_user=self.hs.parse_userid(m.user_id), + target_user=UserID.from_string(m.user_id), auth_user=auth_user, as_event=True, ) presence.append(member_presence) - except Exception: - logger.exception( - "Failed to get member presence of %r", m.user_id - ) + except SynapseError as e: + if e.code == 404: + # FIXME: We are doing this as a warn since this gets hit a + # lot and spams the logs. Why is this happening? + logger.warn( + "Failed to get member presence of %r", m.user_id + ) + else: + logger.exception( + "Failed to get member presence of %r", m.user_id + ) + + time_now = self.clock.time_msec() defer.returnValue({ "membership": member_event.membership, "room_id": room_id, "messages": { - "chunk": [self.hs.serialize_event(m) for m in messages], + "chunk": [serialize_event(m, time_now) for m in messages], "start": start_token.to_string(), "end": end_token.to_string(), }, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 8aeed99274..59287010ed 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -20,6 +20,7 @@ from synapse.api.constants import PresenceState from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import UserID from ._base import BaseHandler @@ -86,6 +87,10 @@ class PresenceHandler(BaseHandler): "changed_presencelike_data", self.changed_presencelike_data ) + # outbound signal from the presence module to advertise when a user's + # presence has changed + distributor.declare("user_presence_changed") + self.distributor = distributor self.federation = hs.get_replication_layer() @@ -96,22 +101,22 @@ class PresenceHandler(BaseHandler): self.federation.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( - observed_user=hs.parse_userid(content["observed_user"]), - observer_user=hs.parse_userid(content["observer_user"]), + observed_user=UserID.from_string(content["observed_user"]), + observer_user=UserID.from_string(content["observer_user"]), ) ) self.federation.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( - observed_user=hs.parse_userid(content["observed_user"]), - observer_user=hs.parse_userid(content["observer_user"]), + observed_user=UserID.from_string(content["observed_user"]), + observer_user=UserID.from_string(content["observer_user"]), ) ) self.federation.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( - observed_user=hs.parse_userid(content["observed_user"]), - observer_user=hs.parse_userid(content["observer_user"]), + observed_user=UserID.from_string(content["observed_user"]), + observer_user=UserID.from_string(content["observer_user"]), ) ) @@ -418,7 +423,7 @@ class PresenceHandler(BaseHandler): ) for p in presence: - observed_user = self.hs.parse_userid(p.pop("observed_user_id")) + observed_user = UserID.from_string(p.pop("observed_user_id")) p["observed_user"] = observed_user p.update(self._get_or_offline_usercache(observed_user).get_state()) if "last_active" in p: @@ -441,7 +446,7 @@ class PresenceHandler(BaseHandler): user.localpart, accepted=True ) target_users = set([ - self.hs.parse_userid(x["observed_user_id"]) for x in presence + UserID.from_string(x["observed_user_id"]) for x in presence ]) # Also include people in all my rooms @@ -452,9 +457,9 @@ class PresenceHandler(BaseHandler): if state is None: state = yield self.store.get_presence_state(user.localpart) else: -# statuscache = self._get_or_make_usercache(user) -# self._user_cachemap_latest_serial += 1 -# statuscache.update(state, self._user_cachemap_latest_serial) + # statuscache = self._get_or_make_usercache(user) + # self._user_cachemap_latest_serial += 1 + # statuscache.update(state, self._user_cachemap_latest_serial) pass yield self.push_update_to_local_and_remote( @@ -603,6 +608,7 @@ class PresenceHandler(BaseHandler): room_ids=room_ids, statuscache=statuscache, ) + yield self.distributor.fire("user_presence_changed", user, statuscache) @defer.inlineCallbacks def _push_presence_remote(self, user, destination, state=None): @@ -646,13 +652,15 @@ class PresenceHandler(BaseHandler): deferreds = [] for push in content.get("push", []): - user = self.hs.parse_userid(push["user_id"]) + user = UserID.from_string(push["user_id"]) logger.debug("Incoming presence update from %s", user) observers = set(self._remote_recvmap.get(user, set())) if observers: - logger.debug(" | %d interested local observers %r", len(observers), observers) + logger.debug( + " | %d interested local observers %r", len(observers), observers + ) rm_handler = self.homeserver.get_handlers().room_member_handler room_ids = yield rm_handler.get_rooms_for_user(user) @@ -694,14 +702,14 @@ class PresenceHandler(BaseHandler): del self._user_cachemap[user] for poll in content.get("poll", []): - user = self.hs.parse_userid(poll) + user = UserID.from_string(poll) if not self.hs.is_mine(user): continue # TODO(paul) permissions checks - if not user in self._remote_sendmap: + if user not in self._remote_sendmap: self._remote_sendmap[user] = set() self._remote_sendmap[user].add(origin) @@ -709,7 +717,7 @@ class PresenceHandler(BaseHandler): deferreds.append(self._push_presence_remote(user, origin)) for unpoll in content.get("unpoll", []): - user = self.hs.parse_userid(unpoll) + user = UserID.from_string(unpoll) if not self.hs.is_mine(user): continue diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 7777d3cc94..03b2159c53 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.constants import EventTypes, Membership from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import UserID from ._base import BaseHandler @@ -169,7 +170,7 @@ class ProfileHandler(BaseHandler): @defer.inlineCallbacks def on_profile_query(self, args): - user = self.hs.parse_userid(args["user_id"]) + user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): raise SynapseError(400, "User is not hosted on this Home Server") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 66a89c10b2..0247327eb9 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -99,6 +99,26 @@ class RegistrationHandler(BaseHandler): raise RegistrationError( 500, "Cannot generate user ID.") + # create a default avatar for the user + # XXX: ideally clients would explicitly specify one, but given they don't + # and we want consistent and pretty identicons for random users, we'll + # do it here. + try: + auth_user = UserID.from_string(user_id) + media_repository = self.hs.get_resource_for_media_repository() + identicon_resource = media_repository.getChildWithDefault("identicon", None) + upload_resource = media_repository.getChildWithDefault("upload", None) + identicon_bytes = identicon_resource.generate_identicon(user_id, 320, 320) + content_uri = yield upload_resource.create_content( + "image/png", None, identicon_bytes, len(identicon_bytes), auth_user + ) + profile_handler = self.hs.get_handlers().profile_handler + profile_handler.set_avatar_url( + auth_user, auth_user, ("%s#auto" % (content_uri,)) + ) + except NotImplementedError: + pass # make tests pass without messing around creating default avatars + defer.returnValue((user_id, token)) @defer.inlineCallbacks diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 59719a1fae..914742d913 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -16,12 +16,14 @@ """Contains functions for performing events on rooms.""" from twisted.internet import defer +from ._base import BaseHandler + from synapse.types import UserID, RoomAlias, RoomID from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import StoreError, SynapseError from synapse.util import stringutils from synapse.util.async import run_on_reactor -from ._base import BaseHandler +from synapse.events.utils import serialize_event import logging @@ -64,7 +66,7 @@ class RoomCreationHandler(BaseHandler): invite_list = config.get("invite", []) for i in invite_list: try: - self.hs.parse_userid(i) + UserID.from_string(i) except: raise SynapseError(400, "Invalid user_id: %s" % (i,)) @@ -114,7 +116,7 @@ class RoomCreationHandler(BaseHandler): servers=[self.hs.hostname], ) - user = self.hs.parse_userid(user_id) + user = UserID.from_string(user_id) creation_events = self._create_events_for_new_room( user, room_id, is_public=is_public ) @@ -246,11 +248,9 @@ class RoomMemberHandler(BaseHandler): @defer.inlineCallbacks def get_room_members(self, room_id): - hs = self.hs - users = yield self.store.get_users_in_room(room_id) - defer.returnValue([hs.parse_userid(u) for u in users]) + defer.returnValue([UserID.from_string(u) for u in users]) @defer.inlineCallbacks def fetch_room_distributions_into(self, room_id, localusers=None, @@ -295,8 +295,9 @@ class RoomMemberHandler(BaseHandler): yield self.auth.check_joined_room(room_id, user_id) member_list = yield self.store.get_room_members(room_id=room_id) + time_now = self.clock.time_msec() event_list = [ - self.hs.serialize_event(entry) + serialize_event(entry, time_now) for entry in member_list ] chunk_data = { @@ -368,7 +369,7 @@ class RoomMemberHandler(BaseHandler): ) if prev_state and prev_state.membership == Membership.JOIN: - user = self.hs.parse_userid(event.user_id) + user = UserID.from_string(event.user_id) self.distributor.fire( "user_left_room", user=user, room_id=event.room_id ) @@ -388,8 +389,6 @@ class RoomMemberHandler(BaseHandler): if not hosts: raise SynapseError(404, "No known servers") - host = hosts[0] - # If event doesn't include a display name, add one. yield self.distributor.fire( "collect_presencelike_data", joinee, content @@ -406,13 +405,13 @@ class RoomMemberHandler(BaseHandler): }) event, context = yield self._create_new_client_event(builder) - yield self._do_join(event, context, room_host=host, do_auth=True) + yield self._do_join(event, context, room_hosts=hosts, do_auth=True) defer.returnValue({"room_id": room_id}) @defer.inlineCallbacks - def _do_join(self, event, context, room_host=None, do_auth=True): - joinee = self.hs.parse_userid(event.state_key) + def _do_join(self, event, context, room_hosts=None, do_auth=True): + joinee = UserID.from_string(event.state_key) # room_id = RoomID.from_string(event.room_id, self.hs) room_id = event.room_id @@ -425,10 +424,22 @@ class RoomMemberHandler(BaseHandler): event.room_id, self.hs.hostname ) + if not is_host_in_room: + # is *anyone* in the room? + room_member_keys = [ + v for (k, v) in context.current_state.keys() if ( + k == "m.room.member" + ) + ] + if len(room_member_keys) == 0: + # has the room been created so we can join it? + create_event = context.current_state.get(("m.room.create", "")) + if create_event: + is_host_in_room = True if is_host_in_room: should_do_dance = False - elif room_host: + elif room_hosts: # TODO: Shouldn't this be remote_room_host? should_do_dance = True else: # TODO(markjh): get prev_state from snapshot @@ -440,17 +451,18 @@ class RoomMemberHandler(BaseHandler): inviter = UserID.from_string(prev_state.user_id) should_do_dance = not self.hs.is_mine(inviter) - room_host = inviter.domain + room_hosts = [inviter.domain] else: - should_do_dance = False + # return the same error as join_room_alias does + raise SynapseError(404, "No known servers") if should_do_dance: handler = self.hs.get_handlers().federation_handler yield handler.do_invite_join( - room_host, + room_hosts, room_id, event.user_id, - event.get_dict()["content"], # FIXME To get a non-frozen dict + event.content, # FIXME To get a non-frozen dict context ) else: @@ -463,7 +475,7 @@ class RoomMemberHandler(BaseHandler): do_auth=do_auth, ) - user = self.hs.parse_userid(event.user_id) + user = UserID.from_string(event.user_id) yield self.distributor.fire( "user_joined_room", user=user, room_id=room_id ) @@ -513,7 +525,7 @@ class RoomMemberHandler(BaseHandler): do_auth): yield run_on_reactor() - target_user = self.hs.parse_userid(event.state_key) + target_user = UserID.from_string(event.state_key) yield self.handle_new_client_event( event, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py new file mode 100644 index 0000000000..7883bbd834 --- /dev/null +++ b/synapse/handlers/sync.py @@ -0,0 +1,439 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseHandler + +from synapse.streams.config import PaginationConfig +from synapse.api.constants import Membership, EventTypes + +from twisted.internet import defer + +import collections +import logging + +logger = logging.getLogger(__name__) + + +SyncConfig = collections.namedtuple("SyncConfig", [ + "user", + "client_info", + "limit", + "gap", + "sort", + "backfill", + "filter", +]) + + +class RoomSyncResult(collections.namedtuple("RoomSyncResult", [ + "room_id", + "limited", + "published", + "events", + "state", + "prev_batch", + "ephemeral", +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if room needs to be part of the sync result. + """ + return bool(self.events or self.state or self.ephemeral) + + +class SyncResult(collections.namedtuple("SyncResult", [ + "next_batch", # Token for the next sync + "private_user_data", # List of private events for the user. + "public_user_data", # List of public events for all users. + "rooms", # RoomSyncResult for each room. +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if the notifier needs to wait for more events when polling for + events. + """ + return bool( + self.private_user_data or self.public_user_data or self.rooms + ) + + +class SyncHandler(BaseHandler): + + def __init__(self, hs): + super(SyncHandler, self).__init__(hs) + self.event_sources = hs.get_event_sources() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0): + """Get the sync for a client if we have new data for it now. Otherwise + wait for new data to arrive on the server. If the timeout expires, then + return an empty sync result. + Returns: + A Deferred SyncResult. + """ + if timeout == 0 or since_token is None: + result = yield self.current_sync_for_user(sync_config, since_token) + defer.returnValue(result) + else: + def current_sync_callback(): + return self.current_sync_for_user(sync_config, since_token) + + rm_handler = self.hs.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + result = yield self.notifier.wait_for_events( + sync_config.user, room_ids, + sync_config.filter, timeout, current_sync_callback + ) + defer.returnValue(result) + + def current_sync_for_user(self, sync_config, since_token=None): + """Get the sync for client needed to match what the server has now. + Returns: + A Deferred SyncResult. + """ + if since_token is None: + return self.initial_sync(sync_config) + else: + if sync_config.gap: + return self.incremental_sync_with_gap(sync_config, since_token) + else: + # TODO(mjark): Handle gapless sync + raise NotImplementedError() + + @defer.inlineCallbacks + def initial_sync(self, sync_config): + """Get a sync for a client which is starting without any state + Returns: + A Deferred SyncResult. + """ + if sync_config.sort == "timeline,desc": + # TODO(mjark): Handle going through events in reverse order?. + # What does "most recent events" mean when applying the limits mean + # in this case? + raise NotImplementedError() + + now_token = yield self.event_sources.get_current_token() + + presence_stream = self.event_sources.sources["presence"] + # TODO (mjark): This looks wrong, shouldn't we be getting the presence + # UP to the present rather than after the present? + pagination_config = PaginationConfig(from_token=now_token) + presence, _ = yield presence_stream.get_pagination_rows( + user=sync_config.user, + pagination_config=pagination_config.get_source_config("presence"), + key=None + ) + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=sync_config.user.to_string(), + membership_list=[Membership.INVITE, Membership.JOIN] + ) + + # TODO (mjark): Does public mean "published"? + published_rooms = yield self.store.get_rooms(is_public=True) + published_room_ids = set(r["room_id"] for r in published_rooms) + + rooms = [] + for event in room_list: + room_sync = yield self.initial_sync_for_room( + event.room_id, sync_config, now_token, published_room_ids + ) + rooms.append(room_sync) + + defer.returnValue(SyncResult( + public_user_data=presence, + private_user_data=[], + rooms=rooms, + next_batch=now_token, + )) + + @defer.inlineCallbacks + def initial_sync_for_room(self, room_id, sync_config, now_token, + published_room_ids): + """Sync a room for a client which is starting without any state + Returns: + A Deferred RoomSyncResult. + """ + + recents, prev_batch_token, limited = yield self.load_filtered_recents( + room_id, sync_config, now_token, + ) + + current_state = yield self.state_handler.get_current_state( + room_id + ) + current_state_events = current_state.values() + + defer.returnValue(RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch_token, + state=current_state_events, + limited=limited, + ephemeral=[], + )) + + @defer.inlineCallbacks + def incremental_sync_with_gap(self, sync_config, since_token): + """ Get the incremental delta needed to bring the client up to + date with the server. + Returns: + A Deferred SyncResult. + """ + if sync_config.sort == "timeline,desc": + # TODO(mjark): Handle going through events in reverse order?. + # What does "most recent events" mean when applying the limits mean + # in this case? + raise NotImplementedError() + + now_token = yield self.event_sources.get_current_token() + + presence_source = self.event_sources.sources["presence"] + presence, presence_key = yield presence_source.get_new_events_for_user( + user=sync_config.user, + from_key=since_token.presence_key, + limit=sync_config.limit, + ) + now_token = now_token.copy_and_replace("presence_key", presence_key) + + typing_source = self.event_sources.sources["typing"] + typing, typing_key = yield typing_source.get_new_events_for_user( + user=sync_config.user, + from_key=since_token.typing_key, + limit=sync_config.limit, + ) + now_token = now_token.copy_and_replace("typing_key", typing_key) + + typing_by_room = {event["room_id"]: [event] for event in typing} + for event in typing: + event.pop("room_id") + logger.debug("Typing %r", typing_by_room) + + rm_handler = self.hs.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + + # TODO (mjark): Does public mean "published"? + published_rooms = yield self.store.get_rooms(is_public=True) + published_room_ids = set(r["room_id"] for r in published_rooms) + + room_events, _ = yield self.store.get_room_events_stream( + sync_config.user.to_string(), + from_key=since_token.room_key, + to_key=now_token.room_key, + room_id=None, + limit=sync_config.limit + 1, + ) + + rooms = [] + if len(room_events) <= sync_config.limit: + # There is no gap in any of the rooms. Therefore we can just + # partition the new events by room and return them. + events_by_room_id = {} + for event in room_events: + events_by_room_id.setdefault(event.room_id, []).append(event) + + for room_id in room_ids: + recents = events_by_room_id.get(room_id, []) + state = [event for event in recents if event.is_state()] + if recents: + prev_batch = now_token.copy_and_replace( + "room_key", recents[0].internal_metadata.before + ) + else: + prev_batch = now_token + + state = yield self.check_joined_room( + sync_config, room_id, state + ) + + room_sync = RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch, + state=state, + limited=False, + ephemeral=typing_by_room.get(room_id, []) + ) + if room_sync: + rooms.append(room_sync) + else: + for room_id in room_ids: + room_sync = yield self.incremental_sync_with_gap_for_room( + room_id, sync_config, since_token, now_token, + published_room_ids, typing_by_room + ) + if room_sync: + rooms.append(room_sync) + + defer.returnValue(SyncResult( + public_user_data=presence, + private_user_data=[], + rooms=rooms, + next_batch=now_token, + )) + + @defer.inlineCallbacks + def load_filtered_recents(self, room_id, sync_config, now_token, + since_token=None): + limited = True + recents = [] + filtering_factor = 2 + load_limit = max(sync_config.limit * filtering_factor, 100) + max_repeat = 3 # Only try a few times per room, otherwise + room_key = now_token.room_key + end_key = room_key + + while limited and len(recents) < sync_config.limit and max_repeat: + events, keys = yield self.store.get_recent_events_for_room( + room_id, + limit=load_limit + 1, + from_token=since_token.room_key if since_token else None, + end_token=end_key, + ) + (room_key, _) = keys + end_key = "s" + room_key.split('-')[-1] + loaded_recents = sync_config.filter.filter_room_events(events) + loaded_recents.extend(recents) + recents = loaded_recents + if len(events) <= load_limit: + limited = False + max_repeat -= 1 + + if len(recents) > sync_config.limit: + recents = recents[-sync_config.limit:] + room_key = recents[0].internal_metadata.before + + prev_batch_token = now_token.copy_and_replace( + "room_key", room_key + ) + + defer.returnValue((recents, prev_batch_token, limited)) + + @defer.inlineCallbacks + def incremental_sync_with_gap_for_room(self, room_id, sync_config, + since_token, now_token, + published_room_ids, typing_by_room): + """ Get the incremental delta needed to bring the client up to date for + the room. Gives the client the most recent events and the changes to + state. + Returns: + A Deferred RoomSyncResult + """ + + # TODO(mjark): Check for redactions we might have missed. + + recents, prev_batch_token, limited = yield self.load_filtered_recents( + room_id, sync_config, now_token, since_token, + ) + + logging.debug("Recents %r", recents) + + # TODO(mjark): This seems racy since this isn't being passed a + # token to indicate what point in the stream this is + current_state = yield self.state_handler.get_current_state( + room_id + ) + current_state_events = current_state.values() + + state_at_previous_sync = yield self.get_state_at_previous_sync( + room_id, since_token=since_token + ) + + state_events_delta = yield self.compute_state_delta( + since_token=since_token, + previous_state=state_at_previous_sync, + current_state=current_state_events, + ) + + state_events_delta = yield self.check_joined_room( + sync_config, room_id, state_events_delta + ) + + room_sync = RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch_token, + state=state_events_delta, + limited=limited, + ephemeral=typing_by_room.get(room_id, []) + ) + + logging.debug("Room sync: %r", room_sync) + + defer.returnValue(room_sync) + + @defer.inlineCallbacks + def get_state_at_previous_sync(self, room_id, since_token): + """ Get the room state at the previous sync the client made. + Returns: + A Deferred list of Events. + """ + last_events, token = yield self.store.get_recent_events_for_room( + room_id, end_token=since_token.room_key, limit=1, + ) + + if last_events: + last_event = last_events[0] + last_context = yield self.state_handler.compute_event_context( + last_event + ) + if last_event.is_state(): + state = [last_event] + last_context.current_state.values() + else: + state = last_context.current_state.values() + else: + state = () + defer.returnValue(state) + + def compute_state_delta(self, since_token, previous_state, current_state): + """ Works out the differnce in state between the current state and the + state the client got when it last performed a sync. + Returns: + A list of events. + """ + # TODO(mjark) Check if the state events were received by the server + # after the previous sync, since we need to include those state + # updates even if they occured logically before the previous event. + # TODO(mjark) Check for new redactions in the state events. + previous_dict = {event.event_id: event for event in previous_state} + state_delta = [] + for event in current_state: + if event.event_id not in previous_dict: + state_delta.append(event) + return state_delta + + @defer.inlineCallbacks + def check_joined_room(self, sync_config, room_id, state_delta): + joined = False + for event in state_delta: + if ( + event.type == EventTypes.Member + and event.state_key == sync_config.user.to_string() + ): + if event.content["membership"] == Membership.JOIN: + joined = True + + if joined: + res = yield self.state_handler.get_current_state(room_id) + state_delta = res.values() + + defer.returnValue(state_delta) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index ab698b36e1..c69787005f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -18,6 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.errors import SynapseError, AuthError +from synapse.types import UserID import logging @@ -83,9 +84,15 @@ class TypingNotificationHandler(BaseHandler): if member in self._member_typing_timer: self.clock.cancel_call_later(self._member_typing_timer[member]) + def _cb(): + logger.debug( + "%s has timed out in %s", target_user.to_string(), room_id + ) + self._stopped_typing(member) + self._member_typing_until[member] = until self._member_typing_timer[member] = self.clock.call_later( - timeout / 1000, lambda: self._stopped_typing(member) + timeout / 1000.0, _cb ) if was_present: @@ -114,6 +121,10 @@ class TypingNotificationHandler(BaseHandler): member = RoomMember(room_id=room_id, user=target_user) + if member in self._member_typing_timer: + self.clock.cancel_call_later(self._member_typing_timer[member]) + del self._member_typing_timer[member] + yield self._stopped_typing(member) @defer.inlineCallbacks @@ -136,8 +147,10 @@ class TypingNotificationHandler(BaseHandler): del self._member_typing_until[member] - self.clock.cancel_call_later(self._member_typing_timer[member]) - del self._member_typing_timer[member] + if member in self._member_typing_timer: + # Don't cancel it - either it already expired, or the real + # stopped_typing() will cancel it + del self._member_typing_timer[member] @defer.inlineCallbacks def _push_update(self, room_id, user, typing): @@ -173,7 +186,7 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] - user = self.homeserver.parse_userid(content["user_id"]) + user = UserID.from_string(content["user_id"]) localusers = set() |