diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/__init__.py | 45 | ||||
-rw-r--r-- | synapse/handlers/_base.py | 422 | ||||
-rw-r--r-- | synapse/handlers/appservice.py | 198 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 390 | ||||
-rw-r--r-- | synapse/handlers/device.py | 181 | ||||
-rw-r--r-- | synapse/handlers/devicemessage.py | 117 | ||||
-rw-r--r-- | synapse/handlers/directory.py | 97 | ||||
-rw-r--r-- | synapse/handlers/e2e_keys.py | 279 | ||||
-rw-r--r-- | synapse/handlers/events.py | 5 | ||||
-rw-r--r-- | synapse/handlers/federation.py | 714 | ||||
-rw-r--r-- | synapse/handlers/identity.py | 41 | ||||
-rw-r--r-- | synapse/handlers/initial_sync.py | 444 | ||||
-rw-r--r-- | synapse/handlers/message.py | 676 | ||||
-rw-r--r-- | synapse/handlers/presence.py | 364 | ||||
-rw-r--r-- | synapse/handlers/profile.py | 70 | ||||
-rw-r--r-- | synapse/handlers/receipts.py | 41 | ||||
-rw-r--r-- | synapse/handlers/register.py | 118 | ||||
-rw-r--r-- | synapse/handlers/room.py | 711 | ||||
-rw-r--r-- | synapse/handlers/room_list.py | 403 | ||||
-rw-r--r-- | synapse/handlers/room_member.py | 735 | ||||
-rw-r--r-- | synapse/handlers/search.py | 17 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 1361 | ||||
-rw-r--r-- | synapse/handlers/typing.py | 296 |
23 files changed, 4809 insertions, 2916 deletions
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 66d2c01123..5ad408f549 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -13,34 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice.scheduler import AppServiceScheduler -from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( - RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler, + RoomCreationHandler, RoomContextHandler, ) +from .room_member import RoomMemberHandler from .message import MessageHandler -from .events import EventStreamHandler, EventHandler from .federation import FederationHandler from .profile import ProfileHandler -from .presence import PresenceHandler from .directory import DirectoryHandler -from .typing import TypingNotificationHandler from .admin import AdminHandler -from .appservice import ApplicationServicesHandler -from .sync import SyncHandler -from .auth import AuthHandler from .identity import IdentityHandler -from .receipts import ReceiptsHandler from .search import SearchHandler class Handlers(object): - """ A collection of all the event handlers. + """ Deprecated. A collection of handlers. - There's no need to lazily create these; we'll just make them all eagerly - at construction time. + At some point most of the classes whose name ended "Handler" were + accessed through this class. + + However this makes it painful to unit test the handlers and to run cut + down versions of synapse that only use specific handlers because using a + single handler required creating all of the handlers. So some of the + handlers have been lifted out of the Handlers object and are now accessed + directly through the homeserver object itself. + + Any new handlers should follow the new pattern of being accessed through + the homeserver object and should not be added to the Handlers object. + + The remaining handlers should be moved out of the handlers object. """ def __init__(self, hs): @@ -48,26 +51,10 @@ class Handlers(object): self.message_handler = MessageHandler(hs) self.room_creation_handler = RoomCreationHandler(hs) self.room_member_handler = RoomMemberHandler(hs) - self.event_stream_handler = EventStreamHandler(hs) - self.event_handler = EventHandler(hs) self.federation_handler = FederationHandler(hs) self.profile_handler = ProfileHandler(hs) - self.presence_handler = PresenceHandler(hs) - self.room_list_handler = RoomListHandler(hs) self.directory_handler = DirectoryHandler(hs) - self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) - self.receipts_handler = ReceiptsHandler(hs) - asapi = ApplicationServiceApi(hs) - self.appservice_handler = ApplicationServicesHandler( - hs, asapi, AppServiceScheduler( - clock=hs.get_clock(), - store=hs.get_datastore(), - as_api=asapi - ) - ) - self.sync_handler = SyncHandler(hs) - self.auth_handler = AuthHandler(hs) self.identity_handler = IdentityHandler(hs) self.search_handler = SearchHandler(hs) self.room_context_handler = RoomContextHandler(hs) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90eabb6eb7..90f96209f8 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -13,39 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer -from synapse.api.errors import LimitExceededError, SynapseError, AuthError -from synapse.crypto.event_signing import add_hashes_and_signatures +import synapse.types from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, RoomAlias, Requester -from synapse.push.action_generator import ActionGenerator - -from synapse.util.logcontext import PreserveLoggingContext - -import logging +from synapse.api.errors import LimitExceededError +from synapse.types import UserID logger = logging.getLogger(__name__) -VISIBILITY_PRIORITY = ( - "world_readable", - "shared", - "invited", - "joined", -) - - class BaseHandler(object): """ Common base class for the event handlers. - :type store: synapse.storage.events.StateStore - :type state_handler: synapse.state.StateHandler + Attributes: + store (synapse.storage.DataStore): + state_handler (synapse.state.StateHandler): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.store = hs.get_datastore() self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -55,141 +49,26 @@ class BaseHandler(object): self.clock = hs.get_clock() self.hs = hs - self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname self.event_builder_factory = hs.get_event_builder_factory() - @defer.inlineCallbacks - def filter_events_for_clients(self, user_tuples, events, event_id_to_state): - """ Returns dict of user_id -> list of events that user is allowed to - see. - - :param (str, bool) user_tuples: (user id, is_peeking) for each - user to be checked. is_peeking should be true if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events - """ - forgotten = yield defer.gatherResults([ - self.store.who_forgot_in_room( - room_id, - ) - for room_id in frozenset(e.room_id for e in events) - ], consumeErrors=True) - - # Set of membership event_ids that have been forgotten - event_id_forgotten = frozenset( - row["event_id"] for rows in forgotten for row in rows - ) - - def allowed(event, user_id, is_peeking): - state = event_id_to_state[event.event_id] - - # get the room_visibility at the time of the event. - visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) - if visibility_event: - visibility = visibility_event.content.get("history_visibility", "shared") - else: - visibility = "shared" - - if visibility not in VISIBILITY_PRIORITY: - visibility = "shared" - - # if it was world_readable, it's easy: everyone can read it - if visibility == "world_readable": - return True - - # Always allow history visibility events on boundaries. This is done - # by setting the effective visibility to the least restrictive - # of the old vs new. - if event.type == EventTypes.RoomHistoryVisibility: - prev_content = event.unsigned.get("prev_content", {}) - prev_visibility = prev_content.get("history_visibility", None) - - if prev_visibility not in VISIBILITY_PRIORITY: - prev_visibility = "shared" - - new_priority = VISIBILITY_PRIORITY.index(visibility) - old_priority = VISIBILITY_PRIORITY.index(prev_visibility) - if old_priority < new_priority: - visibility = prev_visibility - - # get the user's membership at the time of the event. (or rather, - # just *after* the event. Which means that people can see their - # own join events, but not (currently) their own leave events.) - membership_event = state.get((EventTypes.Member, user_id), None) - if membership_event: - if membership_event.event_id in event_id_forgotten: - membership = None - else: - membership = membership_event.membership - else: - membership = None - - # if the user was a member of the room at the time of the event, - # they can see it. - if membership == Membership.JOIN: - return True - - if visibility == "joined": - # we weren't a member at the time of the event, so we can't - # see this event. - return False - - elif visibility == "invited": - # user can also see the event if they were *invited* at the time - # of the event. - return membership == Membership.INVITE + def ratelimit(self, requester): + time_now = self.clock.time() + user_id = requester.user.to_string() - else: - # visibility is shared: user can also see the event if they have - # become a member since the event - # - # XXX: if the user has subsequently joined and then left again, - # ideally we would share history up to the point they left. But - # we don't know when they left. - return not is_peeking + # The AS user itself is never rate limited. + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service is not None: + return # do not ratelimit app service senders - defer.returnValue({ - user_id: [ - event - for event in events - if allowed(event, user_id, is_peeking) - ] - for user_id, is_peeking in user_tuples - }) + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return - @defer.inlineCallbacks - def _filter_events_for_client(self, user_id, events, is_peeking=False): - """ - Check which events a user is allowed to see - - :param str user_id: user id to be checked - :param [synapse.events.EventBase] events: list of events to be checked - :param bool is_peeking should be True if: - * the user is not currently a member of the room, and: - * the user has not been a member of the room since the given - events - :rtype [synapse.events.EventBase] - """ - types = ( - (EventTypes.RoomHistoryVisibility, ""), - (EventTypes.Member, user_id), - ) - event_id_to_state = yield self.store.get_state_for_events( - frozenset(e.event_id for e in events), - types=types - ) - res = yield self.filter_events_for_clients( - [(user_id, is_peeking)], events, event_id_to_state - ) - defer.returnValue(res.get(user_id, [])) - - def ratelimit(self, requester): - time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.send_message( - requester.user.to_string(), time_now, + user_id, time_now, msg_rate_hz=self.hs.config.rc_messages_per_second, burst_count=self.hs.config.rc_message_burst_count, ) @@ -199,252 +78,20 @@ class BaseHandler(object): ) @defer.inlineCallbacks - def _create_new_client_event(self, builder): - latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( - builder.room_id, - ) - - if latest_ret: - depth = max([d for _, _, d in latest_ret]) + 1 - else: - depth = 1 - - prev_events = [ - (event_id, prev_hashes) - for event_id, prev_hashes, _ in latest_ret - ] - - builder.prev_events = prev_events - builder.depth = depth - - state_handler = self.state_handler - - context = yield state_handler.compute_event_context(builder) - - # If we've received an invite over federation, there are no latest - # events in the room, because we don't know enough about the graph - # fragment we received to treat it like a graph, so the above returned - # no relevant events. It may have returned some events (if we have - # joined and left the room), but not useful ones, like the invite. - if ( - not self.is_host_in_room(context.current_state) and - builder.type == EventTypes.Member - ): - prev_member_event = yield self.store.get_room_member( - builder.sender, builder.room_id - ) - - # The prev_member_event may already be in context.current_state, - # despite us not being present in the room; in particular, if - # inviting user, and all other local users, have already left. - # - # In that case, we have all the information we need, and we don't - # want to drop "context" - not least because we may need to handle - # the invite locally, which will require us to have the whole - # context (not just prev_member_event) to auth it. - # - context_event_ids = ( - e.event_id for e in context.current_state.values() - ) - - if ( - prev_member_event and - prev_member_event.event_id not in context_event_ids - ): - # The prev_member_event is missing from context, so it must - # have arrived over federation and is an outlier. We forcibly - # set our context to the invite we received over federation - builder.prev_events = ( - prev_member_event.event_id, - prev_member_event.prev_events - ) - - context = yield state_handler.compute_event_context( - builder, - old_state=(prev_member_event,), - outlier=True - ) - - if builder.is_state(): - builder.prev_state = yield self.store.add_event_hashes( - context.prev_state_events - ) - - yield self.auth.add_auth_events(builder, context) - - add_hashes_and_signatures( - builder, self.server_name, self.signing_key - ) - - event = builder.build() - - logger.debug( - "Created event %s with current state: %s", - event.event_id, context.current_state, - ) - - defer.returnValue( - (event, context,) - ) - - def is_host_in_room(self, current_state): - room_members = [ - (state_key, event.membership) - for ((event_type, state_key), event) in current_state.items() - if event_type == EventTypes.Member - ] - if len(room_members) == 0: - # Have we just created the room, and is this about to be the very - # first member event? - create_event = current_state.get(("m.room.create", "")) - if create_event: - return True - for (state_key, membership) in room_members: - if ( - UserID.from_string(state_key).domain == self.hs.hostname - and membership == Membership.JOIN - ): - return True - return False - - @defer.inlineCallbacks - def handle_new_client_event( - self, - requester, - event, - context, - ratelimit=True, - extra_users=[] - ): - # We now need to go and hit out to wherever we need to hit out to. - - if ratelimit: - self.ratelimit(requester) - - self.auth.check(event, auth_events=context.current_state) - - yield self.maybe_kick_guest_users(event, context.current_state.values()) - - if event.type == EventTypes.CanonicalAlias: - # Check the alias is acually valid (at this time at least) - room_alias_str = event.content.get("alias", None) - if room_alias_str: - room_alias = RoomAlias.from_string(room_alias_str) - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) - - if mapping["room_id"] != event.room_id: - raise SynapseError( - 400, - "Room alias %s does not point to the room" % ( - room_alias_str, - ) - ) - - federation_handler = self.hs.get_handlers().federation_handler - - if event.type == EventTypes.Member: - if event.content["membership"] == Membership.INVITE: - def is_inviter_member_event(e): - return ( - e.type == EventTypes.Member and - e.sender == event.sender - ) - - event.unsigned["invite_room_state"] = [ - { - "type": e.type, - "state_key": e.state_key, - "content": e.content, - "sender": e.sender, - } - for k, e in context.current_state.items() - if e.type in self.hs.config.room_invite_state_types - or is_inviter_member_event(e) - ] - - invitee = UserID.from_string(event.state_key) - if not self.hs.is_mine(invitee): - # TODO: Can we add signature from remote server in a nicer - # way? If we have been invited by a remote server, we need - # to get them to sign the event. - - returned_invite = yield federation_handler.send_invite( - invitee.domain, - event, - ) - - event.unsigned.pop("room_state", None) - - # TODO: Make sure the signatures actually are correct. - event.signatures.update( - returned_invite.signatures - ) - - if event.type == EventTypes.Redaction: - if self.auth.check_redaction(event, auth_events=context.current_state): - original_event = yield self.store.get_event( - event.redacts, - check_redacted=False, - get_prev_content=False, - allow_rejected=False, - allow_none=False - ) - if event.user_id != original_event.user_id: - raise AuthError( - 403, - "You don't have permission to redact events" - ) - - if event.type == EventTypes.Create and context.current_state: - raise AuthError( - 403, - "Changing the room create event is forbidden", - ) - - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( - event, context, self - ) - - (event_stream_id, max_stream_id) = yield self.store.persist_event( - event, context=context - ) - - destinations = set() - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add( - UserID.from_string(s.state_key).domain - ) - except SynapseError: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - - with PreserveLoggingContext(): - # Don't block waiting on waking up all the listeners. - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) - - # If invite, remove room_state from unsigned before sending. - event.unsigned.pop("invite_room_state", None) - - federation_handler.handle_new_event( - event, destinations=destinations, - ) - - @defer.inlineCallbacks - def maybe_kick_guest_users(self, event, current_state): + def maybe_kick_guest_users(self, event, context=None): # Technically this function invalidates current_state by changing it. # Hopefully this isn't that important to the caller. if event.type == EventTypes.GuestAccess: guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": + if context: + current_state = yield self.store.get_events( + context.current_state_ids.values() + ) + current_state = current_state.values() + else: + current_state = yield self.store.get_current_state(event.room_id) + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) @defer.inlineCallbacks @@ -477,7 +124,8 @@ class BaseHandler(object): # and having homeservers have their own users leave keeps more # of that decision-making and control local to the guest-having # homeserver. - requester = Requester(target_user, "", True) + requester = synapse.types.create_requester( + target_user, is_guest=True) handler = self.hs.get_handlers().room_member_handler yield handler.update_membership( requester, diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 75fc74c797..05af54d31b 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -16,8 +16,8 @@ from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.appservice import ApplicationService -from synapse.types import UserID +from synapse.util.metrics import Measure +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred import logging @@ -35,47 +35,81 @@ def log_failure(failure): ) -# NB: Purposefully not inheriting BaseHandler since that contains way too much -# setup code which this handler does not need or use. This makes testing a lot -# easier. class ApplicationServicesHandler(object): - def __init__(self, hs, appservice_api, appservice_scheduler): + def __init__(self, hs): self.store = hs.get_datastore() - self.hs = hs - self.appservice_api = appservice_api - self.scheduler = appservice_scheduler + self.is_mine_id = hs.is_mine_id + self.appservice_api = hs.get_application_service_api() + self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False + self.clock = hs.get_clock() + self.notify_appservices = hs.config.notify_appservices + + self.current_max = 0 + self.is_processing = False @defer.inlineCallbacks - def notify_interested_services(self, event): + def notify_interested_services(self, current_id): """Notifies (pushes) all application services interested in this event. Pushing is done asynchronously, so this method won't block for any prolonged length of time. Args: - event(Event): The event to push out to interested services. + current_id(int): The current maximum ID. """ - # Gather interested services - services = yield self._get_services_for_event(event) - if len(services) == 0: - return # no services need notifying - - # Do we know this user exists? If not, poke the user query API for - # all services which match that user regex. This needs to block as these - # user queries need to be made BEFORE pushing the event. - yield self._check_user_exists(event.sender) - if event.type == EventTypes.Member: - yield self._check_user_exists(event.state_key) - - if not self.started_scheduler: - self.scheduler.start().addErrback(log_failure) - self.started_scheduler = True - - # Fork off pushes to these services - for service in services: - self.scheduler.submit_event_for_as(service, event) + services = self.store.get_app_services() + if not services or not self.notify_appservices: + return + + self.current_max = max(self.current_max, current_id) + if self.is_processing: + return + + with Measure(self.clock, "notify_interested_services"): + self.is_processing = True + try: + upper_bound = self.current_max + limit = 100 + while True: + upper_bound, events = yield self.store.get_new_events_for_appservice( + upper_bound, limit + ) + + if not events: + break + + for event in events: + # Gather interested services + services = yield self._get_services_for_event(event) + if len(services) == 0: + continue # no services need notifying + + # Do we know this user exists? If not, poke the user + # query API for all services which match that user regex. + # This needs to block as these user queries need to be + # made BEFORE pushing the event. + yield self._check_user_exists(event.sender) + if event.type == EventTypes.Member: + yield self._check_user_exists(event.state_key) + + if not self.started_scheduler: + self.scheduler.start().addErrback(log_failure) + self.started_scheduler = True + + # Fork off pushes to these services + for service in services: + preserve_fn(self.scheduler.submit_event_for_as)( + service, event + ) + + yield self.store.set_appservice_last_pos(upper_bound) + + if len(events) < limit: + break + finally: + self.is_processing = False @defer.inlineCallbacks def query_user_exists(self, user_id): @@ -108,11 +142,12 @@ class ApplicationServicesHandler(object): association can be found. """ room_alias_str = room_alias.to_string() - alias_query_services = yield self._get_services_for_event( - event=None, - restrict_to=ApplicationService.NS_ALIASES, - alias_list=[room_alias_str] - ) + services = self.store.get_app_services() + alias_query_services = [ + s for s in services if ( + s.is_interested_in_alias(room_alias_str) + ) + ] for alias_service in alias_query_services: is_known_alias = yield self.appservice_api.query_alias( alias_service, room_alias_str @@ -125,52 +160,97 @@ class ApplicationServicesHandler(object): defer.returnValue(result) @defer.inlineCallbacks - def _get_services_for_event(self, event, restrict_to="", alias_list=None): + def query_3pe(self, kind, protocol, fields): + services = yield self._get_services_for_3pn(protocol) + + results = yield preserve_context_over_deferred(defer.DeferredList([ + preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) + for service in services + ], consumeErrors=True)) + + ret = [] + for (success, result) in results: + if success: + ret.extend(result) + + defer.returnValue(ret) + + @defer.inlineCallbacks + def get_3pe_protocols(self, only_protocol=None): + services = self.store.get_app_services() + protocols = {} + + # Collect up all the individual protocol responses out of the ASes + for s in services: + for p in s.protocols: + if only_protocol is not None and p != only_protocol: + continue + + if p not in protocols: + protocols[p] = [] + + info = yield self.appservice_api.get_3pe_protocol(s, p) + + if info is not None: + protocols[p].append(info) + + def _merge_instances(infos): + if not infos: + return {} + + # Merge the 'instances' lists of multiple results, but just take + # the other fields from the first as they ought to be identical + # copy the result so as not to corrupt the cached one + combined = dict(infos[0]) + combined["instances"] = list(combined["instances"]) + + for info in infos[1:]: + combined["instances"].extend(info["instances"]) + + return combined + + for p in protocols.keys(): + protocols[p] = _merge_instances(protocols[p]) + + defer.returnValue(protocols) + + @defer.inlineCallbacks + def _get_services_for_event(self, event): """Retrieve a list of application services interested in this event. Args: event(Event): The event to check. Can be None if alias_list is not. - restrict_to(str): The namespace to restrict regex tests to. - alias_list: A list of aliases to get services for. If None, this - list is obtained from the database. Returns: list<ApplicationService>: A list of services interested in this event based on the service regex. """ - member_list = None - if hasattr(event, "room_id"): - # We need to know the aliases associated with this event.room_id, - # if any. - if not alias_list: - alias_list = yield self.store.get_aliases_for_room( - event.room_id - ) - # We need to know the members associated with this event.room_id, - # if any. - member_list = yield self.store.get_users_in_room(event.room_id) - - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( - s.is_interested(event, restrict_to, alias_list, member_list) + yield s.is_interested(event, self.store) ) ] defer.returnValue(interested_list) - @defer.inlineCallbacks def _get_services_for_user(self, user_id): - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_list = [ s for s in services if ( s.is_interested_in_user(user_id) ) ] - defer.returnValue(interested_list) + return defer.succeed(interested_list) + + def _get_services_for_3pn(self, protocol): + services = self.store.get_app_services() + interested_list = [ + s for s in services if s.is_interested_in_protocol(protocol) + ] + return defer.succeed(interested_list) @defer.inlineCallbacks def _is_unknown_user(self, user_id): - user = UserID.from_string(user_id) - if not self.hs.is_mine(user): + if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. defer.returnValue(False) @@ -182,7 +262,7 @@ class ApplicationServicesHandler(object): return # user not found; could be the AS though, so check. - services = yield self.store.get_app_services() + services = self.store.get_app_services() service_list = [s for s in services if s.sender == user_id] defer.returnValue(len(service_list) == 0) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 82d458b424..3b146f09d6 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,7 +18,7 @@ from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType from synapse.types import UserID -from synapse.api.errors import AuthError, LoginError, Codes +from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.util.async import run_on_reactor from twisted.web.client import PartialDownloadError @@ -38,6 +38,10 @@ class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(AuthHandler, self).__init__(hs) self.checkers = { LoginType.PASSWORD: self._check_password_auth, @@ -47,7 +51,20 @@ class AuthHandler(BaseHandler): } self.bcrypt_rounds = hs.config.bcrypt_rounds self.sessions = {} - self.INVALID_TOKEN_HTTP_STATUS = 401 + + account_handler = _AccountHandler( + hs, check_user_exists=self.check_user_exists + ) + + self.password_providers = [ + module(config=config, account_handler=account_handler) + for module, config in hs.config.password_providers + ] + + logger.info("Extra password_providers: %r", self.password_providers) + + self.hs = hs # FIXME better possibility to access registrationHandler later? + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): @@ -118,21 +135,47 @@ class AuthHandler(BaseHandler): creds = session['creds'] # check auth type currently being presented + errordict = {} if 'type' in authdict: - if authdict['type'] not in self.checkers: + login_type = authdict['type'] + if login_type not in self.checkers: raise LoginError(400, "", Codes.UNRECOGNIZED) - result = yield self.checkers[authdict['type']](authdict, clientip) - if result: - creds[authdict['type']] = result - self._save_session(session) + try: + result = yield self.checkers[login_type](authdict, clientip) + if result: + creds[login_type] = result + self._save_session(session) + except LoginError, e: + if login_type == LoginType.EMAIL_IDENTITY: + # riot used to have a bug where it would request a new + # validation token (thus sending a new email) each time it + # got a 401 with a 'flows' field. + # (https://github.com/vector-im/vector-web/issues/2447). + # + # Grandfather in the old behaviour for now to avoid + # breaking old riot deployments. + raise e + + # this step failed. Merge the error dict into the response + # so that the client can have another go. + errordict = e.error_dict() for f in flows: if len(set(f) - set(creds.keys())) == 0: - logger.info("Auth completed with creds: %r", creds) + # it's very useful to know what args are stored, but this can + # include the password in the case of registering, so only log + # the keys (confusingly, clientdict may contain a password + # param, creds is just what the user authed as for UI auth + # and is not sensitive). + logger.info( + "Auth completed with creds: %r. Client dict has keys: %r", + creds, clientdict.keys() + ) defer.returnValue((True, creds, clientdict, session['id'])) ret = self._auth_dict_for_flows(flows, session) ret['completed'] = creds.keys() + ret.update(errordict) defer.returnValue((False, ret, clientdict, session['id'])) @defer.inlineCallbacks @@ -163,9 +206,13 @@ class AuthHandler(BaseHandler): def get_session_id(self, clientdict): """ Gets the session ID for a client given the client dictionary - :param clientdict: The dictionary sent by the client in the request - :return: The string session ID the client sent. If the client did not - send a session ID, returns None. + + Args: + clientdict: The dictionary sent by the client in the request + + Returns: + str|None: The string session ID the client sent. If the client did + not send a session ID, returns None. """ sid = None if clientdict and 'auth' in clientdict: @@ -179,9 +226,11 @@ class AuthHandler(BaseHandler): Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param value: (any) The data to store + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + value (any): The data to store """ sess = self._get_session_info(session_id) sess.setdefault('serverdict', {})[key] = value @@ -190,14 +239,15 @@ class AuthHandler(BaseHandler): def get_session_data(self, session_id, key, default=None): """ Retrieve data stored with set_session_data - :param session_id: (string) The ID of this session as returned from check_auth - :param key: (string) The key to store the data under - :param default: (any) Value to return if the key has not been set + + Args: + session_id (string): The ID of this session as returned from check_auth + key (string): The key to store the data under + default (any): Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) - @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) @@ -207,9 +257,7 @@ class AuthHandler(BaseHandler): if not user_id.startswith('@'): user_id = UserID.create(user_id, self.hs.hostname).to_string() - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) - defer.returnValue(user_id) + return self._check_password(user_id, password) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -245,8 +293,17 @@ class AuthHandler(BaseHandler): data = pde.response resp_body = simplejson.loads(data) - if 'success' in resp_body and resp_body['success']: - defer.returnValue(True) + if 'success' in resp_body: + # Note that we do NOT check the hostname here: we explicitly + # intend the CAPTCHA to be presented by whatever client the + # user is using, we just care that they have completed a CAPTCHA. + logger.info( + "%s reCAPTCHA from hostname %s", + "Successful" if resp_body['success'] else "Failed", + resp_body.get('hostname') + ) + if resp_body['success']: + defer.returnValue(True) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) @defer.inlineCallbacks @@ -313,147 +370,205 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def login_with_password(self, user_id, password): + def validate_password_login(self, user_id, password): """ Authenticates the user with their username and password. Used only by the v1 login API. Args: - user_id (str): User ID + user_id (str): complete @user:id password (str): Password Returns: - A tuple of: - The user's ID. - The access token for the user's session. - The refresh token for the user's session. + defer.Deferred: (str) canonical user id Raises: - StoreError if there was a problem storing the token. + StoreError if there was a problem accessing the database LoginError if there was an authentication problem. """ - user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) - self._check_password(user_id, password, password_hash) - - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + return self._check_password(user_id, password) @defer.inlineCallbacks - def get_login_tuple_for_user_id(self, user_id): + def get_access_token_for_user_id(self, user_id, device_id=None, + initial_display_name=None): """ - Gets login tuple for the user with the given user ID. + Creates a new access token for the user with the given user ID. + The user is assumed to have been authenticated by some other - machanism (e.g. CAS) + machanism (e.g. CAS), and the user_id converted to the canonical case. + + The device will be recorded in the table if it is not there already. Args: - user_id (str): User ID + user_id (str): canonical User ID + device_id (str|None): the device ID to associate with the tokens. + None to leave the tokens unassociated with a device (deprecated: + we should always have a device ID) + initial_display_name (str): display name to associate with the + device if it needs re-registering Returns: - A tuple of: - The user's ID. The access token for the user's session. - The refresh token for the user's session. Raises: StoreError if there was a problem storing the token. LoginError if there was an authentication problem. """ - user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) + logger.info("Logging in user %s on device %s", user_id, device_id) + access_token = yield self.issue_access_token(user_id, device_id) + + # the device *should* have been registered before we got here; however, + # it's possible we raced against a DELETE operation. The thing we + # really don't want is active access_tokens without a record of the + # device, so we double-check it here. + if device_id is not None: + yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) - logger.info("Logging in user %s", user_id) - access_token = yield self.issue_access_token(user_id) - refresh_token = yield self.issue_refresh_token(user_id) - defer.returnValue((user_id, access_token, refresh_token)) + defer.returnValue(access_token) @defer.inlineCallbacks - def does_user_exist(self, user_id): - try: - yield self._find_user_id_and_pwd_hash(user_id) - defer.returnValue(True) - except LoginError: - defer.returnValue(False) + def check_user_exists(self, user_id): + """ + Checks to see if a user with the given id exists. Will check case + insensitively, but return None if there are multiple inexact matches. + + Args: + (str) user_id: complete @user:id + + Returns: + defer.Deferred: (str) canonical_user_id, or None if zero or + multiple matches + """ + res = yield self._find_user_id_and_pwd_hash(user_id) + if res is not None: + defer.returnValue(res[0]) + defer.returnValue(None) @defer.inlineCallbacks def _find_user_id_and_pwd_hash(self, user_id): """Checks to see if a user with the given id exists. Will check case - insensitively, but will throw if there are multiple inexact matches. + insensitively, but will return None if there are multiple inexact + matches. Returns: tuple: A 2-tuple of `(canonical_user_id, password_hash)` + None: if there is not exactly one match """ user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + + result = None if not user_infos: logger.warn("Attempted to login as %s but they do not exist", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - if len(user_infos) > 1: - if user_id not in user_infos: - logger.warn( - "Attempted to login as %s but it matches more than one user " - "inexactly: %r", - user_id, user_infos.keys() - ) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - defer.returnValue((user_id, user_infos[user_id])) + elif len(user_infos) == 1: + # a single match (possibly not exact) + result = user_infos.popitem() + elif user_id in user_infos: + # multiple matches, but one is exact + result = (user_id, user_infos[user_id]) else: - defer.returnValue(user_infos.popitem()) + # multiple matches, none of them exact + logger.warn( + "Attempted to login as %s but it matches more than one user " + "inexactly: %r", + user_id, user_infos.keys() + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _check_password(self, user_id, password): + """Authenticate a user against the LDAP and local databases. - def _check_password(self, user_id, password, stored_hash): - """Checks that user_id has passed password, raises LoginError if not.""" - if not self.validate_hash(password, stored_hash): + user_id is checked case insensitively against the local database, but + will throw if there are multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id + Raises: + LoginError if login fails + """ + for provider in self.password_providers: + is_valid = yield provider.check_password(user_id, password) + if is_valid: + defer.returnValue(user_id) + + canonical_user_id = yield self._check_local_password(user_id, password) + + if canonical_user_id: + defer.returnValue(canonical_user_id) + + # unknown username or invalid password. We raise a 403 here, but note + # that if we're doing user-interactive login, it turns all LoginErrors + # into a 401 anyway. + raise LoginError( + 403, "Invalid password", + errcode=Codes.FORBIDDEN + ) + + @defer.inlineCallbacks + def _check_local_password(self, user_id, password): + """Authenticate a user against the local password database. + + user_id is checked case insensitively, but will return None if there are + multiple inexact matches. + + Args: + user_id (str): complete @user:id + Returns: + (str) the canonical_user_id, or None if unknown user / bad password + """ + lookupres = yield self._find_user_id_and_pwd_hash(user_id) + if not lookupres: + defer.returnValue(None) + (user_id, password_hash) = lookupres + result = self.validate_hash(password, password_hash) + if not result: logger.warn("Failed password login for user %s", user_id) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) + defer.returnValue(None) + defer.returnValue(user_id) @defer.inlineCallbacks - def issue_access_token(self, user_id): + def issue_access_token(self, user_id, device_id=None): access_token = self.generate_access_token(user_id) - yield self.store.add_access_token_to_user(user_id, access_token) + yield self.store.add_access_token_to_user(user_id, access_token, + device_id) defer.returnValue(access_token) - @defer.inlineCallbacks - def issue_refresh_token(self, user_id): - refresh_token = self.generate_refresh_token(user_id) - yield self.store.add_refresh_token_to_user(user_id, refresh_token) - defer.returnValue(refresh_token) - def generate_access_token(self, user_id, extra_caveats=None): extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") - now = self.hs.get_clock().time_msec() - expiry = now + (60 * 60 * 1000) - macaroon.add_first_party_caveat("time < %d" % (expiry,)) + # Include a nonce, to make sure that each login gets a different + # access token. + macaroon.add_first_party_caveat("nonce = %s" % ( + stringutils.random_string_with_symbols(16), + )) for caveat in extra_caveats: macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_refresh_token(self, user_id): - m = self._generate_base_macaroon(user_id) - m.add_first_party_caveat("type = refresh") - # Important to add a nonce, because otherwise every refresh token for a - # user will be the same. - m.add_first_party_caveat("nonce = %s" % ( - stringutils.random_string_with_symbols(16), - )) - return m.serialize() - - def generate_short_term_login_token(self, user_id): + def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() - expiry = now + (2 * 60 * 1000) + expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() + def generate_delete_pusher_token(self, user_id): + macaroon = self._generate_base_macaroon(user_id) + macaroon.add_first_party_caveat("type = delete_pusher") + return macaroon.serialize() + def validate_short_term_login_token_and_get_user_id(self, login_token): + auth_api = self.hs.get_auth() try: macaroon = pymacaroons.Macaroon.deserialize(login_token) - auth_api = self.hs.get_auth() - auth_api.validate_macaroon(macaroon, "login", True) - return self.get_user_from_macaroon(macaroon) - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): - raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) + user_id = auth_api.get_user_id_from_macaroon(macaroon) + auth_api.validate_macaroon(macaroon, "login", True, user_id) + return user_id + except Exception: + raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) def _generate_base_macaroon(self, user_id): macaroon = pymacaroons.Macaroon( @@ -464,32 +579,39 @@ class AuthHandler(BaseHandler): macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - def get_user_from_macaroon(self, macaroon): - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix):] - raise AuthError( - self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token", - errcode=Codes.UNKNOWN_TOKEN - ) - @defer.inlineCallbacks def set_password(self, user_id, newpassword, requester=None): password_hash = self.hash(newpassword) - except_access_token_ids = [requester.access_token_id] if requester else [] + except_access_token_id = requester.access_token_id if requester else None - yield self.store.user_set_password_hash(user_id, password_hash) + try: + yield self.store.user_set_password_hash(user_id, password_hash) + except StoreError as e: + if e.code == 404: + raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) + raise e yield self.store.user_delete_access_tokens( - user_id, except_access_token_ids + user_id, except_access_token_id ) yield self.hs.get_pusherpool().remove_pushers_by_user( - user_id, except_access_token_ids + user_id, except_access_token_id ) @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): + # 'Canonicalise' email addresses down to lower case. + # We've now moving towards the Home Server being the entity that + # is responsible for validating threepids used for resetting passwords + # on accounts, so in future Synapse will gain knowledge of specific + # types (mediums) of threepid. For now, we still use the existing + # infrastructure, but this is the start of synapse gaining knowledge + # of specific types of threepid (and fixes the fact that checking + # for the presenc eof an email address during password reset was + # case sensitive). + if medium == 'email': + address = address.lower() + yield self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() @@ -520,7 +642,8 @@ class AuthHandler(BaseHandler): Returns: Hashed password (str). """ - return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) + return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, + bcrypt.gensalt(self.bcrypt_rounds)) def validate_hash(self, password, stored_hash): """Validates that self.hash(password) == stored_hash. @@ -532,4 +655,35 @@ class AuthHandler(BaseHandler): Returns: Whether self.hash(password) == stored_hash (bool). """ - return bcrypt.hashpw(password, stored_hash) == stored_hash + if stored_hash: + return bcrypt.hashpw(password + self.hs.config.password_pepper, + stored_hash.encode('utf-8')) == stored_hash + else: + return False + + +class _AccountHandler(object): + """A proxy object that gets passed to password auth providers so they + can register new users etc if necessary. + """ + def __init__(self, hs, check_user_exists): + self.hs = hs + + self._check_user_exists = check_user_exists + + def check_user_exists(self, user_id): + """Check if user exissts. + + Returns: + Deferred(bool) + """ + return self._check_user_exists(user_id) + + def register(self, localpart): + """Registers a new user with given localpart + + Returns: + Deferred: a 2-tuple of (user_id, access_token) + """ + reg = self.hs.get_handlers().registration_handler + return reg.register(localpart=localpart) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py new file mode 100644 index 0000000000..aa68755936 --- /dev/null +++ b/synapse/handlers/device.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api import errors +from synapse.util import stringutils +from twisted.internet import defer +from ._base import BaseHandler + +import logging + +logger = logging.getLogger(__name__) + + +class DeviceHandler(BaseHandler): + def __init__(self, hs): + super(DeviceHandler, self).__init__(hs) + + @defer.inlineCallbacks + def check_device_registered(self, user_id, device_id, + initial_device_display_name=None): + """ + If the given device has not been registered, register it with the + supplied display name. + + If no device_id is supplied, we make one up. + + Args: + user_id (str): @user:id + device_id (str | None): device id supplied by client + initial_device_display_name (str | None): device display name from + client + Returns: + str: device id (generated if none was supplied) + """ + if device_id is not None: + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=True, + ) + defer.returnValue(device_id) + + # if the device id is not specified, we'll autogen one, but loop a few + # times in case of a clash. + attempts = 0 + while attempts < 5: + try: + device_id = stringutils.random_string(10).upper() + yield self.store.store_device( + user_id=user_id, + device_id=device_id, + initial_device_display_name=initial_device_display_name, + ignore_if_known=False, + ) + defer.returnValue(device_id) + except errors.StoreError: + attempts += 1 + + raise errors.StoreError(500, "Couldn't generate a device ID.") + + @defer.inlineCallbacks + def get_devices_by_user(self, user_id): + """ + Retrieve the given user's devices + + Args: + user_id (str): + Returns: + defer.Deferred: list[dict[str, X]]: info on each device + """ + + device_map = yield self.store.get_devices_by_user(user_id) + + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id) for device_id in device_map.keys()) + ) + + devices = device_map.values() + for device in devices: + _update_device_from_client_ips(device, ips) + + defer.returnValue(devices) + + @defer.inlineCallbacks + def get_device(self, user_id, device_id): + """ Retrieve the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: dict[str, X]: info on the device + Raises: + errors.NotFoundError: if the device was not found + """ + try: + device = yield self.store.get_device(user_id, device_id) + except errors.StoreError: + raise errors.NotFoundError + ips = yield self.store.get_last_client_ip_by_device( + devices=((user_id, device_id),) + ) + _update_device_from_client_ips(device, ips) + defer.returnValue(device) + + @defer.inlineCallbacks + def delete_device(self, user_id, device_id): + """ Delete the given device + + Args: + user_id (str): + device_id (str): + + Returns: + defer.Deferred: + """ + + try: + yield self.store.delete_device(user_id, device_id) + except errors.StoreError, e: + if e.code == 404: + # no match + pass + else: + raise + + yield self.store.user_delete_access_tokens( + user_id, device_id=device_id, + delete_refresh_tokens=True, + ) + + yield self.store.delete_e2e_keys_by_device( + user_id=user_id, device_id=device_id + ) + + @defer.inlineCallbacks + def update_device(self, user_id, device_id, content): + """ Update the given device + + Args: + user_id (str): + device_id (str): + content (dict): body of update request + + Returns: + defer.Deferred: + """ + + try: + yield self.store.update_device( + user_id, + device_id, + new_display_name=content.get("display_name") + ) + except errors.StoreError, e: + if e.code == 404: + raise errors.NotFoundError() + else: + raise + + +def _update_device_from_client_ips(device, client_ips): + ip = client_ips.get((device["user_id"], device["device_id"]), {}) + device.update({ + "last_seen_ts": ip.get("last_seen"), + "last_seen_ip": ip.get("ip"), + }) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py new file mode 100644 index 0000000000..f7fad15c62 --- /dev/null +++ b/synapse/handlers/devicemessage.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.types import get_domain_from_id +from synapse.util.stringutils import random_string + + +logger = logging.getLogger(__name__) + + +class DeviceMessageHandler(object): + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.is_mine_id = hs.is_mine_id + self.federation = hs.get_federation_sender() + + hs.get_replication_layer().register_edu_handler( + "m.direct_to_device", self.on_direct_to_device_edu + ) + + @defer.inlineCallbacks + def on_direct_to_device_edu(self, origin, content): + local_messages = {} + sender_user_id = content["sender"] + if origin != get_domain_from_id(sender_user_id): + logger.warn( + "Dropping device message from %r with spoofed sender %r", + origin, sender_user_id + ) + message_type = content["type"] + message_id = content["message_id"] + for user_id, by_device in content["messages"].items(): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + + stream_id = yield self.store.add_messages_from_remote_to_device_inbox( + origin, message_id, local_messages + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + @defer.inlineCallbacks + def send_device_message(self, sender_user_id, message_type, messages): + + local_messages = {} + remote_messages = {} + for user_id, by_device in messages.items(): + if self.is_mine_id(user_id): + messages_by_device = { + device_id: { + "content": message_content, + "type": message_type, + "sender": sender_user_id, + } + for device_id, message_content in by_device.items() + } + if messages_by_device: + local_messages[user_id] = messages_by_device + else: + destination = get_domain_from_id(user_id) + remote_messages.setdefault(destination, {})[user_id] = by_device + + message_id = random_string(16) + + remote_edu_contents = {} + for destination, messages in remote_messages.items(): + remote_edu_contents[destination] = { + "messages": messages, + "sender": sender_user_id, + "type": message_type, + "message_id": message_id, + } + + stream_id = yield self.store.add_messages_to_device_inbox( + local_messages, remote_edu_contents + ) + + self.notifier.on_new_event( + "to_device_key", stream_id, users=local_messages.keys() + ) + + for destination in remote_messages.keys(): + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + self.federation.send_device_messages(destination) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index c4aaa11918..c00274afc3 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,7 +19,7 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes -from synapse.types import RoomAlias, UserID +from synapse.types import RoomAlias, UserID, get_domain_from_id import logging import string @@ -32,6 +32,9 @@ class DirectoryHandler(BaseHandler): def __init__(self, hs): super(DirectoryHandler, self).__init__(hs) + self.state = hs.get_state_handler() + self.appservice_handler = hs.get_application_service_handler() + self.federation = hs.get_replication_layer() self.federation.register_query_handler( "directory", self.on_directory_query @@ -52,7 +55,8 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) if not servers: raise SynapseError(400, "Failed to get server list") @@ -93,7 +97,7 @@ class DirectoryHandler(BaseHandler): yield self._create_association(room_alias, room_id, servers) @defer.inlineCallbacks - def delete_association(self, user_id, room_alias): + def delete_association(self, requester, user_id, room_alias): # association deletion for human users can_delete = yield self._user_can_delete_alias(room_alias, user_id) @@ -112,7 +116,25 @@ class DirectoryHandler(BaseHandler): errcode=Codes.EXCLUSIVE ) - yield self._delete_association(room_alias) + room_id = yield self._delete_association(room_alias) + + try: + yield self.send_room_alias_update_event( + requester, + requester.user.to_string(), + room_id + ) + + yield self._update_canonical_alias( + requester, + requester.user.to_string(), + room_id, + room_alias, + ) + except AuthError as e: + logger.info("Failed to update alias events: %s", e) + + defer.returnValue(room_id) @defer.inlineCallbacks def delete_appservice_association(self, service, room_alias): @@ -129,11 +151,9 @@ class DirectoryHandler(BaseHandler): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") - yield self.store.delete_room_alias(room_alias) + room_id = yield self.store.delete_room_alias(room_alias) - # TODO - Looks like _update_room_alias_event has never been implemented - # if room_id: - # yield self._update_room_alias_events(user_id, room_id) + defer.returnValue(room_id) @defer.inlineCallbacks def get_association(self, room_alias): @@ -174,7 +194,8 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND ) - extra_servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + extra_servers = set(get_domain_from_id(u) for u in users) servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. @@ -234,23 +255,45 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks + def _update_canonical_alias(self, requester, user_id, room_id, room_alias): + alias_event = yield self.state.get_current_state( + room_id, EventTypes.CanonicalAlias, "" + ) + + alias_str = room_alias.to_string() + if not alias_event or alias_event.content.get("alias", "") != alias_str: + return + + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.CanonicalAlias, + "state_key": "", + "room_id": room_id, + "sender": user_id, + "content": {}, + }, + ratelimit=False + ) + + @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): result = yield self.store.get_association_from_room_alias( room_alias ) if not result: # Query AS to see if it exists - as_handler = self.hs.get_handlers().appservice_handler + as_handler = self.appservice_handler result = yield as_handler.query_room_alias_exists(room_alias) defer.returnValue(result) - @defer.inlineCallbacks def can_modify_alias(self, alias, user_id=None): # Any application service "interested" in an alias they are regexing on # can modify the alias. # Users can only modify the alias if ALL the interested services have # non-exclusive locks on the alias (or there are no interested services) - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_alias(alias.to_string()) ] @@ -258,14 +301,12 @@ class DirectoryHandler(BaseHandler): for service in interested_services: if user_id == service.sender: # this user IS the app service so they can do whatever they like - defer.returnValue(True) - return + return defer.succeed(True) elif service.is_exclusive_alias(alias.to_string()): # another service has an exclusive lock on this alias. - defer.returnValue(False) - return + return defer.succeed(False) # either no interested services, or no service with an exclusive lock - defer.returnValue(True) + return defer.succeed(True) @defer.inlineCallbacks def _user_can_delete_alias(self, alias, user_id): @@ -276,3 +317,25 @@ class DirectoryHandler(BaseHandler): is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id)) defer.returnValue(is_admin) + + @defer.inlineCallbacks + def edit_published_room_list(self, requester, room_id, visibility): + """Edit the entry of the room in the published room list. + + requester + room_id (str) + visibility (str): "public" or "private" + """ + if requester.is_guest: + raise AuthError(403, "Guests cannot edit the published room list") + + if visibility not in ["public", "private"]: + raise SynapseError(400, "Invalid visibility setting") + + room = yield self.store.get_room(room_id) + if room is None: + raise SynapseError(400, "Unknown room") + + yield self.auth.check_can_change_room_list(room_id, requester.user) + + yield self.store.set_room_is_public(room_id, visibility == "public") diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py new file mode 100644 index 0000000000..fd11935b40 --- /dev/null +++ b/synapse/handlers/e2e_keys.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ujson as json +import logging + +from canonicaljson import encode_canonical_json +from twisted.internet import defer + +from synapse.api.errors import SynapseError, CodeMessageException +from synapse.types import get_domain_from_id +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination + +logger = logging.getLogger(__name__) + + +class E2eKeysHandler(object): + def __init__(self, hs): + self.store = hs.get_datastore() + self.federation = hs.get_replication_layer() + self.device_handler = hs.get_device_handler() + self.is_mine_id = hs.is_mine_id + self.clock = hs.get_clock() + + # doesn't really work as part of the generic query API, because the + # query request requires an object POST, but we abuse the + # "query handler" interface. + self.federation.register_query_handler( + "client_keys", self.on_federation_query_client_keys + ) + + @defer.inlineCallbacks + def query_devices(self, query_body, timeout): + """ Handle a device key query from a client + + { + "device_keys": { + "<user_id>": ["<device_id>"] + } + } + -> + { + "device_keys": { + "<user_id>": { + "<device_id>": { + ... + } + } + } + } + """ + device_keys_query = query_body.get("device_keys", {}) + + # separate users by domain. + # make a map from domain to user_id to device_ids + local_query = {} + remote_queries = {} + + for user_id, device_ids in device_keys_query.items(): + if self.is_mine_id(user_id): + local_query[user_id] = device_ids + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_ids + + # do the queries + failures = {} + results = {} + if local_query: + local_result = yield self.query_local_devices(local_query) + for user_id, keys in local_result.items(): + if user_id in local_query: + results[user_id] = keys + + @defer.inlineCallbacks + def do_remote_query(destination): + destination_query = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.query_client_keys( + destination, + {"device_keys": destination_query}, + timeout=timeout + ) + + for user_id, keys in remote_result["device_keys"].items(): + if user_id in destination_query: + results[user_id] = keys + + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(do_remote_query)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "device_keys": results, "failures": failures, + }) + + @defer.inlineCallbacks + def query_local_devices(self, query): + """Get E2E device keys for local users + + Args: + query (dict[string, list[string]|None): map from user_id to a list + of devices to query (None for all devices) + + Returns: + defer.Deferred: (resolves to dict[string, dict[string, dict]]): + map from user_id -> device_id -> device details + """ + local_query = [] + + result_dict = {} + for user_id, device_ids in query.items(): + if not self.is_mine_id(user_id): + logger.warning("Request for keys for non-local user %s", + user_id) + raise SynapseError(400, "Not a user here") + + if not device_ids: + local_query.append((user_id, None)) + else: + for device_id in device_ids: + local_query.append((user_id, device_id)) + + # make sure that each queried user appears in the result dict + result_dict[user_id] = {} + + results = yield self.store.get_e2e_device_keys(local_query) + + # Build the result structure, un-jsonify the results, and add the + # "unsigned" section + for user_id, device_keys in results.items(): + for device_id, device_info in device_keys.items(): + r = json.loads(device_info["key_json"]) + r["unsigned"] = {} + display_name = device_info["device_display_name"] + if display_name is not None: + r["unsigned"]["device_display_name"] = display_name + result_dict[user_id][device_id] = r + + defer.returnValue(result_dict) + + @defer.inlineCallbacks + def on_federation_query_client_keys(self, query_body): + """ Handle a device key query from a federated server + """ + device_keys_query = query_body.get("device_keys", {}) + res = yield self.query_local_devices(device_keys_query) + defer.returnValue({"device_keys": res}) + + @defer.inlineCallbacks + def claim_one_time_keys(self, query, timeout): + local_query = [] + remote_queries = {} + + for user_id, device_keys in query.get("one_time_keys", {}).items(): + if self.is_mine_id(user_id): + for device_id, algorithm in device_keys.items(): + local_query.append((user_id, device_id, algorithm)) + else: + domain = get_domain_from_id(user_id) + remote_queries.setdefault(domain, {})[user_id] = device_keys + + results = yield self.store.claim_e2e_one_time_keys(local_query) + + json_result = {} + failures = {} + for user_id, device_keys in results.items(): + for device_id, keys in device_keys.items(): + for key_id, json_bytes in keys.items(): + json_result.setdefault(user_id, {})[device_id] = { + key_id: json.loads(json_bytes) + } + + @defer.inlineCallbacks + def claim_client_keys(destination): + device_keys = remote_queries[destination] + try: + limiter = yield get_retry_limiter( + destination, self.clock, self.store + ) + with limiter: + remote_result = yield self.federation.claim_client_keys( + destination, + {"one_time_keys": device_keys}, + timeout=timeout + ) + for user_id, keys in remote_result["one_time_keys"].items(): + if user_id in device_keys: + json_result[user_id] = keys + except CodeMessageException as e: + failures[destination] = { + "status": e.code, "message": e.message + } + except NotRetryingDestination as e: + failures[destination] = { + "status": 503, "message": "Not ready for retry", + } + + yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(claim_client_keys)(destination) + for destination in remote_queries + ])) + + defer.returnValue({ + "one_time_keys": json_result, + "failures": failures + }) + + @defer.inlineCallbacks + def upload_keys_for_user(self, user_id, device_id, keys): + time_now = self.clock.time_msec() + + # TODO: Validate the JSON to make sure it has the right keys. + device_keys = keys.get("device_keys", None) + if device_keys: + logger.info( + "Updating device_keys for device %r for user %s at %d", + device_id, user_id, time_now + ) + # TODO: Sign the JSON with the server key + yield self.store.set_e2e_device_keys( + user_id, device_id, time_now, + encode_canonical_json(device_keys) + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + logger.info( + "Adding %d one_time_keys for device %r for user %r at %d", + len(one_time_keys), device_id, user_id, time_now + ) + key_list = [] + for key_id, key_json in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, encode_canonical_json(key_json) + )) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, key_list + ) + + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered(user_id, device_id) + + result = yield self.store.count_e2e_one_time_keys(user_id, device_id) + + defer.returnValue({"one_time_key_counts": result}) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index f25a252523..d3685fb12a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -58,7 +59,7 @@ class EventStreamHandler(BaseHandler): If `only_keys` is not None, events from keys will be sent down. """ auth_user = UserID.from_string(auth_user_id) - presence_handler = self.hs.get_handlers().presence_handler + presence_handler = self.hs.get_presence_handler() context = yield presence_handler.user_syncing( auth_user_id, affect_presence=affect_presence, @@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.store.get_users_in_room(event.room_id) + users = yield self.state.get_current_user_in_room(event.room_id) states = yield presence_handler.get_states( users, as_event=True, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f599e817aa..771ab3bc43 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -26,20 +26,24 @@ from synapse.api.errors import ( from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.events.validator import EventValidator from synapse.util import unwrapFirstError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import ( + PreserveLoggingContext, preserve_fn, preserve_context_over_deferred +) +from synapse.util.metrics import measure_func from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( compute_event_signature, add_hashes_and_signatures, ) -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id from synapse.events.utils import prune_event from synapse.util.retryutils import NotRetryingDestination from synapse.push.action_generator import ActionGenerator +from synapse.util.distributor import user_joined_room from twisted.internet import defer @@ -49,10 +53,6 @@ import logging logger = logging.getLogger(__name__) -def user_joined_room(distributor, user, room_id): - return distributor.fire("user_joined_room", user, room_id) - - class FederationHandler(BaseHandler): """Handles events that originated from federation. Responsible for: @@ -69,10 +69,6 @@ class FederationHandler(BaseHandler): self.hs = hs - self.distributor.observe("user_joined_room", self.user_joined_room) - - self.waiting_for_join_list = {} - self.store = hs.get_datastore() self.replication_layer = hs.get_replication_layer() self.state_handler = hs.get_state_handler() @@ -84,28 +80,14 @@ class FederationHandler(BaseHandler): # When joining a room we need to queue any events for that room up self.room_queues = {} - def handle_new_event(self, event, destinations): - """ Takes in an event from the client to server side, that has already - been authed and handled by the state module, and sends it to any - remote home servers that may be interested. - - Args: - event: The event to send - destinations: A list of destinations to send it to - - Returns: - Deferred: Resolved when it has successfully been queued for - processing. - """ - - return self.replication_layer.send_pdu(event, destinations) - @log_function @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, backfilled, state=None, - auth_chain=None): + def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): """ Called by the ReplicationLayer when we have a new pdu. We need to do auth checks and put it through the StateHandler. + + auth_chain and state are None if we already have the necessary state + and prev_events in the db """ event = pdu @@ -123,17 +105,25 @@ class FederationHandler(BaseHandler): # FIXME (erikj): Awful hack to make the case where we are not currently # in the room work - current_state = None - is_in_room = yield self.auth.check_host_in_room( - event.room_id, - self.server_name - ) - if not is_in_room and not event.internal_metadata.is_outlier(): - logger.debug("Got event for room we're not in.") + # If state and auth_chain are None, then we don't need to do this check + # as we already know we have enough state in the DB to handle this + # event. + if state and auth_chain and not event.internal_metadata.is_outlier(): + is_in_room = yield self.auth.check_host_in_room( + event.room_id, + self.server_name + ) + else: + is_in_room = True + if not is_in_room: + logger.info( + "Got event for room we're not in: %r %r", + event.room_id, event.event_id + ) try: event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) except AuthError as e: raise FederationError( @@ -175,19 +165,13 @@ class FederationHandler(BaseHandler): }) seen_ids.add(e.event_id) - yield self._handle_new_events( - origin, - event_infos, - outliers=True - ) + yield self._handle_new_events(origin, event_infos) try: context, event_stream_id, max_stream_id = yield self._handle_new_event( origin, event, state=state, - backfilled=backfilled, - current_state=current_state, ) except AuthError as e: raise FederationError( @@ -216,32 +200,42 @@ class FederationHandler(BaseHandler): except StoreError: logger.exception("Failed to store room.") - if not backfilled: - extra_users = [] - if event.type == EventTypes.Member: - target_user_id = event.state_key - target_user = UserID.from_string(target_user_id) - extra_users.append(target_user) + extra_users = [] + if event.type == EventTypes.Member: + target_user_id = event.state_key + target_user = UserID.from_string(target_user_id) + extra_users.append(target_user) - with PreserveLoggingContext(): - self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, - extra_users=extra_users - ) + with PreserveLoggingContext(): + self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users + ) if event.type == EventTypes.Member: if event.membership == Membership.JOIN: - prev_state = context.current_state.get((event.type, event.state_key)) - if not prev_state or prev_state.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally - # joined the room. Don't bother if the user is just - # changing their profile info. + # Only fire user_joined_room if the user has acutally + # joined the room. Don't bother if the user is just + # changing their profile info. + newly_joined = True + prev_state_id = context.prev_state_ids.get( + (event.type, event.state_key) + ) + if prev_state_id: + prev_state = yield self.store.get_event( + prev_state_id, allow_none=True, + ) + if prev_state and prev_state.membership == Membership.JOIN: + newly_joined = False + + if newly_joined: user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) + @measure_func("_filter_events_for_server") @defer.inlineCallbacks def _filter_events_for_server(self, server_name, room_id, events): - event_to_state = yield self.store.get_state_for_events( + event_to_state_ids = yield self.store.get_state_ids_for_events( frozenset(e.event_id for e in events), types=( (EventTypes.RoomHistoryVisibility, ""), @@ -249,6 +243,30 @@ class FederationHandler(BaseHandler): ) ) + # We only want to pull out member events that correspond to the + # server's domain. + + def check_match(id): + try: + return server_name == get_domain_from_id(id) + except: + return False + + event_map = yield self.store.get_events([ + e_id for key_to_eid in event_to_state_ids.values() + for key, e_id in key_to_eid + if key[0] != EventTypes.Member or check_match(key[1]) + ]) + + event_to_state = { + e_id: { + key: event_map[inner_e_id] + for key, inner_e_id in key_to_eid.items() + if inner_e_id in event_map + } + for e_id, key_to_eid in event_to_state_ids.items() + } + def redact_disallowed(event, state): if not state: return event @@ -265,7 +283,7 @@ class FederationHandler(BaseHandler): if ev.type != EventTypes.Member: continue try: - domain = UserID.from_string(ev.state_key).domain + domain = get_domain_from_id(ev.state_key) except: continue @@ -290,11 +308,15 @@ class FederationHandler(BaseHandler): @log_function @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities=[]): + def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` + + This will attempt to get more events from the remote. This may return + be successfull and still return no events if the other side has no new + events to offer. """ - if not extremities: - extremities = yield self.store.get_oldest_events_in_room(room_id) + if dest == self.server_name: + raise SynapseError(400, "Can't backfill from self.") events = yield self.replication_layer.backfill( dest, @@ -303,6 +325,16 @@ class FederationHandler(BaseHandler): extremities=extremities, ) + # Don't bother processing events we already have. + seen_events = yield self.store.have_events_in_timeline( + set(e.event_id for e in events) + ) + + events = [e for e in events if e.event_id not in seen_events] + + if not events: + defer.returnValue([]) + event_map = {e.event_id: e for e in events} event_ids = set(e.event_id for e in events) @@ -334,40 +366,73 @@ class FederationHandler(BaseHandler): state_events.update({s.event_id: s for s in state}) events_to_state[e_id] = state - seen_events = yield self.store.have_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - all_events = events + state_events.values() + auth_events.values() required_auth = set( - a_id for event in all_events for a_id, _ in event.auth_events + a_id + for event in events + state_events.values() + auth_events.values() + for a_id, _ in event.auth_events ) - + auth_events.update({ + e_id: event_map[e_id] for e_id in required_auth if e_id in event_map + }) missing_auth = required_auth - set(auth_events) - results = yield defer.gatherResults( - [ - self.replication_layer.get_pdu( - [dest], - event_id, - outlier=True, - timeout=10000, + failed_to_fetch = set() + + # Try and fetch any missing auth events from both DB and remote servers. + # We repeatedly do this until we stop finding new auth events. + while missing_auth - failed_to_fetch: + logger.info("Missing auth for backfill: %r", missing_auth) + ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + auth_events.update(ret_events) + + required_auth.update( + a_id for event in ret_events.values() for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + if missing_auth - failed_to_fetch: + logger.info( + "Fetching missing auth for backfill: %r", + missing_auth - failed_to_fetch ) - for event_id in missing_auth - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results}) + + results = yield preserve_context_over_deferred(defer.gatherResults( + [ + preserve_fn(self.replication_layer.get_pdu)( + [dest], + event_id, + outlier=True, + timeout=10000, + ) + for event_id in missing_auth - failed_to_fetch + ], + consumeErrors=True + )).addErrback(unwrapFirstError) + auth_events.update({a.event_id: a for a in results if a}) + required_auth.update( + a_id + for event in results if event + for a_id, _ in event.auth_events + ) + missing_auth = required_auth - set(auth_events) + + failed_to_fetch = missing_auth - set(auth_events) + + seen_events = yield self.store.have_events( + set(auth_events.keys()) | set(state_events.keys()) + ) ev_infos = [] for a in auth_events.values(): if a.event_id in seen_events: continue + a.internal_metadata.outlier = True ev_infos.append({ "event": a, "auth_events": { (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in a.auth_events + if a_id in auth_events } }) @@ -379,23 +444,27 @@ class FederationHandler(BaseHandler): (auth_events[a_id].type, auth_events[a_id].state_key): auth_events[a_id] for a_id, _ in event_map[e_id].auth_events + if a_id in auth_events } }) + yield self._handle_new_events( + dest, ev_infos, + backfilled=True, + ) + events.sort(key=lambda e: e.depth) for event in events: if event in events_to_state: continue - ev_infos.append({ - "event": event, - }) - - yield self._handle_new_events( - dest, ev_infos, - backfilled=True, - ) + # We store these one at a time since each event depends on the + # previous to work out the state. + # TODO: We can probably do something more clever here. + yield self._handle_new_event( + dest, event, backfilled=True, + ) defer.returnValue(events) @@ -419,6 +488,10 @@ class FederationHandler(BaseHandler): ) max_depth = sorted_extremeties_tuple[0][1] + # We don't want to specify too many extremities as it causes the backfill + # request URI to be too long. + extremities = dict(sorted_extremeties_tuple[:5]) + if current_depth > max_depth: logger.debug( "Not backfilling as we don't need to. %d < %d", @@ -444,7 +517,7 @@ class FederationHandler(BaseHandler): joined_domains = {} for u, d in joined_users: try: - dom = UserID.from_string(u).domain + dom = get_domain_from_id(u) old_d = joined_domains.get(dom) if old_d: joined_domains[dom] = min(d, old_d) @@ -459,7 +532,7 @@ class FederationHandler(BaseHandler): likely_domains = [ domain for domain, depth in curr_domains - if domain is not self.server_name + if domain != self.server_name ] @defer.inlineCallbacks @@ -467,11 +540,15 @@ class FederationHandler(BaseHandler): # TODO: Should we try multiple of these at a time? for dom in domains: try: - events = yield self.backfill( + yield self.backfill( dom, room_id, limit=100, extremities=[e for e in extremities.keys()] ) + # If this succeeded then we probably already have the + # appropriate stuff. + # TODO: We can probably do something more intelligent here. + defer.returnValue(True) except SynapseError as e: logger.info( "Failed to backfill from %s because %s", @@ -497,8 +574,6 @@ class FederationHandler(BaseHandler): ) continue - if events: - defer.returnValue(True) defer.returnValue(False) success = yield try_backfill(likely_domains) @@ -513,12 +588,24 @@ class FederationHandler(BaseHandler): event_ids = list(extremities.keys()) - states = yield defer.gatherResults([ - self.state_handler.resolve_state_groups(room_id, [e]) + states = yield preserve_context_over_deferred(defer.gatherResults([ + preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e]) for e in event_ids - ]) + ])) states = dict(zip(event_ids, [s[1] for s in states])) + state_map = yield self.store.get_events( + [e_id for ids in states.values() for e_id in ids], + get_prev_content=False + ) + states = { + key: { + k: state_map[e_id] + for k, e_id in state_dict.items() + if e_id in state_map + } for key, state_dict in states.items() + } + for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) @@ -628,7 +715,7 @@ class FederationHandler(BaseHandler): pass event_stream_id, max_stream_id = yield self._persist_auth_tree( - auth_chain, state, event + origin, auth_chain, state, event ) with PreserveLoggingContext(): @@ -647,7 +734,7 @@ class FederationHandler(BaseHandler): continue try: - self.on_receive_pdu(origin, p, backfilled=False) + self.on_receive_pdu(origin, p) except: logger.exception("Couldn't handle pdu") @@ -670,11 +757,18 @@ class FederationHandler(BaseHandler): "state_key": user_id, }) - event, context = yield self._create_new_client_event( - builder=builder, - ) + try: + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( + builder=builder, + ) + except AuthError as e: + logger.warn("Failed to create join %r because %s", event, e) + raise e - self.auth.check(event, auth_events=context.current_state) + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_join_request` + yield self.auth.check_from_context(event, context, do_sig_check=False) defer.returnValue(event) @@ -720,39 +814,15 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield user_joined_room(self.distributor, user, event.room_id) - new_pdu = event - - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.JOIN: - destinations.add( - UserID.from_string(s.state_key).domain - ) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - - destinations.discard(origin) - - logger.debug( - "on_send_join_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - - state_ids = [e.event_id for e in context.current_state.values()] + state_ids = context.prev_state_ids.values() auth_chain = yield self.store.get_auth_chain(set( [event.event_id] + state_ids )) + state = yield self.store.get_events(context.prev_state_ids.values()) + defer.returnValue({ - "state": context.current_state.values(), + "state": state.values(), "auth_chain": auth_chain, }) @@ -765,6 +835,7 @@ class FederationHandler(BaseHandler): event = pdu event.internal_metadata.outlier = True + event.internal_metadata.invite_from_remote = True event.signatures.update( compute_event_signature( @@ -779,7 +850,6 @@ class FederationHandler(BaseHandler): event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, - backfilled=False, ) target_user = UserID.from_string(event.state_key) @@ -793,13 +863,19 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id): - origin, event = yield self._make_and_verify_event( - target_hosts, - room_id, - user_id, - "leave" - ) - signed_event = self._sign_event(event) + try: + origin, event = yield self._make_and_verify_event( + target_hosts, + room_id, + user_id, + "leave" + ) + signed_event = self._sign_event(event) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") # Try the host we successfully got a response to /make_join/ # request first. @@ -809,17 +885,22 @@ class FederationHandler(BaseHandler): except ValueError: pass - yield self.replication_layer.send_leave( - target_hosts, - signed_event - ) + try: + yield self.replication_layer.send_leave( + target_hosts, + signed_event + ) + except SynapseError: + raise + except CodeMessageException as e: + logger.warn("Failed to reject invite: %s", e) + raise SynapseError(500, "Failed to reject invite") context = yield self.state_handler.compute_event_context(event) event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, - backfilled=False, ) target_user = UserID.from_string(event.state_key) @@ -889,11 +970,18 @@ class FederationHandler(BaseHandler): "state_key": user_id, }) - event, context = yield self._create_new_client_event( + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( builder=builder, ) - self.auth.check(event, auth_events=context.current_state) + try: + # The remote hasn't signed it yet, obviously. We'll do the full checks + # when we get the event back in `on_send_leave_request` + yield self.auth.check_from_context(event, context, do_sig_check=False) + except AuthError as e: + logger.warn("Failed to create new leave %r because %s", event, e) + raise e defer.returnValue(event) @@ -932,43 +1020,14 @@ class FederationHandler(BaseHandler): event, event_stream_id, max_stream_id, extra_users=extra_users ) - new_pdu = event - - destinations = set() - - for k, s in context.current_state.items(): - try: - if k[0] == EventTypes.Member: - if s.content["membership"] == Membership.LEAVE: - destinations.add( - UserID.from_string(s.state_key).domain - ) - except: - logger.warn( - "Failed to get destination from event %s", s.event_id - ) - - destinations.discard(origin) - - logger.debug( - "on_send_leave_request: Sending event: %s, signatures: %s", - event.event_id, - event.signatures, - ) - - self.replication_layer.send_pdu(new_pdu, destinations) - defer.returnValue(None) @defer.inlineCallbacks - def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): + def get_state_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ yield run_on_reactor() - if do_auth: - in_room = yield self.auth.check_host_in_room(room_id, origin) - if not in_room: - raise AuthError(403, "Host not in room.") - state_groups = yield self.store.get_state_groups( room_id, [event_id] ) @@ -992,19 +1051,50 @@ class FederationHandler(BaseHandler): res = results.values() for event in res: - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + # We sign these again because there was a bug where we + # incorrectly signed things the first time round + if self.hs.is_mine_id(event.event_id): + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) defer.returnValue(res) else: defer.returnValue([]) @defer.inlineCallbacks + def get_state_ids_for_pdu(self, room_id, event_id): + """Returns the state at the event. i.e. not including said event. + """ + yield run_on_reactor() + + state_groups = yield self.store.get_state_groups_ids( + room_id, [event_id] + ) + + if state_groups: + _, state = state_groups.items().pop() + results = state + + event = yield self.store.get_event(event_id) + if event and event.is_state(): + # Get previous state + if "replaces_state" in event.unsigned: + prev_id = event.unsigned["replaces_state"] + if prev_id != event.event_id: + results[(event.type, event.state_key)] = prev_id + else: + del results[(event.type, event.state_key)] + + defer.returnValue(results.values()) + else: + defer.returnValue([]) + + @defer.inlineCallbacks @log_function def on_backfill_request(self, origin, room_id, pdu_list, limit): in_room = yield self.auth.check_host_in_room(room_id, origin) @@ -1036,16 +1126,17 @@ class FederationHandler(BaseHandler): ) if event: - # FIXME: This is a temporary work around where we occasionally - # return events slightly differently than when they were - # originally signed - event.signatures.update( - compute_event_signature( - event, - self.hs.hostname, - self.hs.config.signing_key[0] + if self.hs.is_mine_id(event.event_id): + # FIXME: This is a temporary work around where we occasionally + # return events slightly differently than when they were + # originally signed + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) ) - ) if do_auth: in_room = yield self.auth.check_host_in_room( @@ -1055,6 +1146,12 @@ class FederationHandler(BaseHandler): if not in_room: raise AuthError(403, "Host not in room.") + events = yield self._filter_events_for_server( + origin, event.room_id, [event] + ) + + event = events[0] + defer.returnValue(event) else: defer.returnValue(None) @@ -1063,50 +1160,47 @@ class FederationHandler(BaseHandler): def get_min_depth_for_context(self, context): return self.store.get_min_depth(context) - @log_function - def user_joined_room(self, user, room_id): - waiters = self.waiting_for_join_list.get( - (user.to_string(), room_id), - [] - ) - while waiters: - waiters.pop().callback(None) - @defer.inlineCallbacks @log_function - def _handle_new_event(self, origin, event, state=None, backfilled=False, - current_state=None, auth_events=None): - - outlier = event.internal_metadata.is_outlier() - + def _handle_new_event(self, origin, event, state=None, auth_events=None, + backfilled=False): context = yield self._prep_event( origin, event, state=state, auth_events=auth_events, ) - if not backfilled and not event.internal_metadata.is_outlier(): + if not event.internal_metadata.is_outlier(): action_generator = ActionGenerator(self.hs) yield action_generator.handle_push_actions_for_event( - event, context, self + event, context ) event_stream_id, max_stream_id = yield self.store.persist_event( event, context=context, backfilled=backfilled, - is_new_state=(not outlier and not backfilled), - current_state=current_state, ) + if not backfilled: + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id + ) + defer.returnValue((context, event_stream_id, max_stream_id)) @defer.inlineCallbacks - def _handle_new_events(self, origin, event_infos, backfilled=False, - outliers=False): - contexts = yield defer.gatherResults( + def _handle_new_events(self, origin, event_infos, backfilled=False): + """Creates the appropriate contexts and persists events. The events + should not depend on one another, e.g. this should be used to persist + a bunch of outliers, but not a chunk of individual events that depend + on each other for state calculations. + """ + contexts = yield preserve_context_over_deferred(defer.gatherResults( [ - self._prep_event( + preserve_fn(self._prep_event)( origin, ev_info["event"], state=ev_info.get("state"), @@ -1114,7 +1208,7 @@ class FederationHandler(BaseHandler): ) for ev_info in event_infos ] - ) + )) yield self.store.persist_events( [ @@ -1122,30 +1216,35 @@ class FederationHandler(BaseHandler): for ev_info, context in itertools.izip(event_infos, contexts) ], backfilled=backfilled, - is_new_state=(not outliers and not backfilled), ) @defer.inlineCallbacks - def _persist_auth_tree(self, auth_events, state, event): + def _persist_auth_tree(self, origin, auth_events, state, event): """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event seperately. + Will attempt to fetch missing auth events. + + Args: + origin (str): Where the events came from + auth_events (list) + state (list) + event (Event) + Returns: 2-tuple of (event_stream_id, max_stream_id) from the persist_event call for `event` """ events_to_context = {} for e in itertools.chain(auth_events, state): - ctx = yield self.state_handler.compute_event_context( - e, outlier=True, - ) - events_to_context[e.event_id] = ctx e.internal_metadata.outlier = True + ctx = yield self.state_handler.compute_event_context(e) + events_to_context[e.event_id] = ctx event_map = { e.event_id: e - for e in auth_events + for e in itertools.chain(auth_events, state, [event]) } create_event = None @@ -1154,10 +1253,29 @@ class FederationHandler(BaseHandler): create_event = e break + missing_auth_events = set() + for e in itertools.chain(auth_events, state, [event]): + for e_id, _ in e.auth_events: + if e_id not in event_map: + missing_auth_events.add(e_id) + + for e_id in missing_auth_events: + m_ev = yield self.replication_layer.get_pdu( + [origin], + e_id, + outlier=True, + timeout=10000, + ) + if m_ev and m_ev.event_id == e_id: + event_map[e_id] = m_ev + else: + logger.info("Failed to find auth event %r", e_id) + for e in itertools.chain(auth_events, state, [event]): auth_for_e = { (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] for e_id, _ in e.auth_events + if e_id in event_map } if create_event: auth_for_e[(EventTypes.Create, "")] = create_event @@ -1185,17 +1303,14 @@ class FederationHandler(BaseHandler): (e, events_to_context[e.event_id]) for e in itertools.chain(auth_events, state) ], - is_new_state=False, ) new_event_context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=False, + event, old_state=state ) event_stream_id, max_stream_id = yield self.store.persist_event( event, new_event_context, - backfilled=False, - is_new_state=True, current_state=state, ) @@ -1203,14 +1318,19 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def _prep_event(self, origin, event, state=None, auth_events=None): - outlier = event.internal_metadata.is_outlier() context = yield self.state_handler.compute_event_context( - event, old_state=state, outlier=outlier, + event, old_state=state, ) if not auth_events: - auth_events = context.current_state + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, + ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } # This is a hack to fix some old rooms where the initial join event # didn't reference the create event in its auth events. @@ -1236,8 +1356,7 @@ class FederationHandler(BaseHandler): context.rejected = RejectedReason.AUTH_ERROR if event.type == EventTypes.GuestAccess: - full_context = yield self.store.get_current_state(room_id=event.room_id) - yield self.maybe_kick_guest_users(event, full_context) + yield self.maybe_kick_guest_users(event) defer.returnValue(context) @@ -1305,6 +1424,11 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) event_auth_events = set(e_id for e_id, _ in event.auth_events) + if event.is_state(): + event_key = (event.type, event.state_key) + else: + event_key = None + if event_auth_events - current_state: have_events = yield self.store.have_events( event_auth_events - current_state @@ -1378,9 +1502,9 @@ class FederationHandler(BaseHandler): # Do auth conflict res. logger.info("Different auth: %s", different_auth) - different_events = yield defer.gatherResults( + different_events = yield preserve_context_over_deferred(defer.gatherResults( [ - self.store.get_event( + preserve_fn(self.store.get_event)( d, allow_none=True, allow_rejected=False, @@ -1389,13 +1513,13 @@ class FederationHandler(BaseHandler): if d in have_events and not have_events[d] ], consumeErrors=True - ).addErrback(unwrapFirstError) + )).addErrback(unwrapFirstError) if different_events: local_view = dict(auth_events) remote_view = dict(auth_events) remote_view.update({ - (d.type, d.state_key): d for d in different_events + (d.type, d.state_key): d for d in different_events if d }) new_state, prev_state = self.state_handler.resolve_events( @@ -1408,8 +1532,16 @@ class FederationHandler(BaseHandler): current_state = set(e.event_id for e in auth_events.values()) different_auth = event_auth_events - current_state - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids = dict(context.current_state_ids) + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids = dict(context.prev_state_ids) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() if different_auth and not event.internal_metadata.is_outlier(): logger.info("Different auth after resolution: %s", different_auth) @@ -1430,8 +1562,8 @@ class FederationHandler(BaseHandler): if do_resolution: # 1. Get what we think is the auth chain. - auth_ids = self.auth.compute_auth_events( - event, context.current_state + auth_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids ) local_auth_chain = yield self.store.get_auth_chain(auth_ids) @@ -1487,13 +1619,22 @@ class FederationHandler(BaseHandler): # 4. Look at rejects and their proofs. # TODO. - context.current_state.update(auth_events) - context.state_group = None + context.current_state_ids = dict(context.current_state_ids) + context.current_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + if k != event_key + }) + context.prev_state_ids = dict(context.prev_state_ids) + context.prev_state_ids.update({ + k: a.event_id for k, a in auth_events.items() + }) + context.state_group = self.store.get_next_state_group() try: self.auth.check(event, auth_events=auth_events) - except AuthError: - raise + except AuthError as e: + logger.warn("Failed auth resolution for %r because %s", event, e) + raise e @defer.inlineCallbacks def construct_auth_difference(self, local_auth, remote_auth): @@ -1663,14 +1804,22 @@ class FederationHandler(BaseHandler): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - event, context = yield self._create_new_client_event(builder=builder) + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( + builder=builder + ) event, context = yield self.add_display_name_to_third_party_invite( event_dict, event, context ) - self.auth.check(event, context.current_state) - yield self._check_signature(event, auth_events=context.current_state) + try: + yield self.auth.check_from_context(event, context) + except AuthError as e: + logger.warn("Denying new third party invite %r because %s", event, e) + raise e + + yield self._check_signature(event, context) member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event(None, event, context) else: @@ -1686,7 +1835,8 @@ class FederationHandler(BaseHandler): def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): builder = self.event_builder_factory.new(event_dict) - event, context = yield self._create_new_client_event( + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event( builder=builder, ) @@ -1694,8 +1844,12 @@ class FederationHandler(BaseHandler): event_dict, event, context ) - self.auth.check(event, auth_events=context.current_state) - yield self._check_signature(event, auth_events=context.current_state) + try: + self.auth.check_from_context(event, context) + except AuthError as e: + logger.warn("Denying third party invite %r because %s", event, e) + raise e + yield self._check_signature(event, context) returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. @@ -1709,41 +1863,56 @@ class FederationHandler(BaseHandler): EventTypes.ThirdPartyInvite, event.content["third_party_invite"]["signed"]["token"] ) - original_invite = context.current_state.get(key) - if not original_invite: + original_invite = None + original_invite_id = context.prev_state_ids.get(key) + if original_invite_id: + original_invite = yield self.store.get_event( + original_invite_id, allow_none=True + ) + if original_invite: + display_name = original_invite.content["display_name"] + event_dict["content"]["third_party_invite"]["display_name"] = display_name + else: logger.info( - "Could not find invite event for third_party_invite - " - "discarding: %s" % (event_dict,) + "Could not find invite event for third_party_invite: %r", + event_dict ) - return + # We don't discard here as this is not the appropriate place to do + # auth checks. If we need the invite and don't have it then the + # auth check code will explode appropriately. - display_name = original_invite.content["display_name"] - event_dict["content"]["third_party_invite"]["display_name"] = display_name builder = self.event_builder_factory.new(event_dict) EventValidator().validate_new(builder) - event, context = yield self._create_new_client_event(builder=builder) + message_handler = self.hs.get_handlers().message_handler + event, context = yield message_handler._create_new_client_event(builder=builder) defer.returnValue((event, context)) @defer.inlineCallbacks - def _check_signature(self, event, auth_events): + def _check_signature(self, event, context): """ Checks that the signature in the event is consistent with its invite. - :param event (Event): The m.room.member event to check - :param auth_events (dict<(event type, state_key), event>) - :raises - AuthError if signature didn't match any keys, or key has been + Args: + event (Event): The m.room.member event to check + context (EventContext): + + Raises: + AuthError: if signature didn't match any keys, or key has been revoked, - SynapseError if a transient error meant a key couldn't be checked + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ signed = event.content["third_party_invite"]["signed"] token = signed["token"] - invite_event = auth_events.get( + invite_event_id = context.prev_state_ids.get( (EventTypes.ThirdPartyInvite, token,) ) + invite_event = None + if invite_event_id: + invite_event = yield self.store.get_event(invite_event_id, allow_none=True) + if not invite_event: raise AuthError(403, "Could not find invite") @@ -1776,12 +1945,13 @@ class FederationHandler(BaseHandler): """ Checks whether public_key has been revoked. - :param public_key (str): base-64 encoded public key. - :param url (str): Key revocation URL. + Args: + public_key (str): base-64 encoded public key. + url (str): Key revocation URL. - :raises - AuthError if they key has been revoked. - SynapseError if a transient error meant a key couldn't be checked + Raises: + AuthError: if they key has been revoked. + SynapseError: if a transient error meant a key couldn't be checked for revocation. """ try: diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 656ce124f9..559e5d5a71 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,7 +21,7 @@ from synapse.api.errors import ( ) from ._base import BaseHandler from synapse.util.async import run_on_reactor -from synapse.api.errors import SynapseError +from synapse.api.errors import SynapseError, Codes import json import logging @@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler): hs.config.use_insecure_ssl_client_just_for_testing_do_not_use ) + def _should_trust_id_server(self, id_server): + if id_server not in self.trusted_id_servers: + if self.trust_any_id_server_just_for_testing_do_not_use: + logger.warn( + "Trusting untrustworthy ID server %r even though it isn't" + " in the trusted id list for testing because" + " 'use_insecure_ssl_client_just_for_testing_do_not_use'" + " is set in the config", + id_server, + ) + else: + return False + return True + @defer.inlineCallbacks def threepid_from_creds(self, creds): yield run_on_reactor() @@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler): else: raise SynapseError(400, "No client_secret in creds") - if id_server not in self.trusted_id_servers: - if self.trust_any_id_server_just_for_testing_do_not_use: - logger.warn( - "Trusting untrustworthy ID server %r even though it isn't" - " in the trusted id list for testing because" - " 'use_insecure_ssl_client_just_for_testing_do_not_use'" - " is set in the config", - id_server, - ) - else: - logger.warn('%s is not a trusted ID server: rejecting 3pid ' + - 'credentials', id_server) - defer.returnValue(None) + if not self._should_trust_id_server(id_server): + logger.warn( + '%s is not a trusted ID server: rejecting 3pid ' + + 'credentials', id_server + ) + defer.returnValue(None) data = {} try: @@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler): def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): yield run_on_reactor() + if not self._should_trust_id_server(id_server): + raise SynapseError( + 400, "Untrusted ID server '%s'" % id_server, + Codes.SERVER_NOT_TRUSTED + ) + params = { 'email': email, 'client_secret': client_secret, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py new file mode 100644 index 0000000000..e0ade4c164 --- /dev/null +++ b/synapse/handlers/initial_sync.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import AuthError, Codes +from synapse.events.utils import serialize_event +from synapse.events.validator import EventValidator +from synapse.streams.config import PaginationConfig +from synapse.types import ( + UserID, StreamToken, +) +from synapse.util import unwrapFirstError +from synapse.util.async import concurrently_execute +from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.visibility import filter_events_for_client + +from ._base import BaseHandler + +import logging + + +logger = logging.getLogger(__name__) + + +class InitialSyncHandler(BaseHandler): + def __init__(self, hs): + super(InitialSyncHandler, self).__init__(hs) + self.hs = hs + self.state = hs.get_state_handler() + self.clock = hs.get_clock() + self.validator = EventValidator() + self.snapshot_cache = SnapshotCache() + + def snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + """Retrieve a snapshot of all rooms the user is invited or has joined. + + This snapshot may include messages for all rooms where the user is + joined, depending on the pagination config. + + Args: + user_id (str): The ID of the user making the request. + pagin_config (synapse.api.streams.PaginationConfig): The pagination + config used to determine how many messages *PER ROOM* to return. + as_client_event (bool): True to get events in client-server format. + include_archived (bool): True to get rooms that the user has left + Returns: + A list of dicts with "room_id" and "membership" keys for all rooms + the user is currently invited or joined in on. Rooms where the user + is joined on, may return a "messages" key with messages, depending + on the specified PaginationConfig. + """ + key = ( + user_id, + pagin_config.from_token, + pagin_config.to_token, + pagin_config.direction, + pagin_config.limit, + as_client_event, + include_archived, + ) + now_ms = self.clock.time_msec() + result = self.snapshot_cache.get(now_ms, key) + if result is not None: + return result + + return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( + user_id, pagin_config, as_client_event, include_archived + )) + + @defer.inlineCallbacks + def _snapshot_all_rooms(self, user_id=None, pagin_config=None, + as_client_event=True, include_archived=False): + + memberships = [Membership.INVITE, Membership.JOIN] + if include_archived: + memberships.append(Membership.LEAVE) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, membership_list=memberships + ) + + user = UserID.from_string(user_id) + + rooms_ret = [] + + now_token = yield self.hs.get_event_sources().get_current_token() + + presence_stream = self.hs.get_event_sources().sources["presence"] + pagination_config = PaginationConfig(from_token=now_token) + presence, _ = yield presence_stream.get_pagination_rows( + user, pagination_config.get_source_config("presence"), None + ) + + receipt_stream = self.hs.get_event_sources().sources["receipt"] + receipt, _ = yield receipt_stream.get_pagination_rows( + user, pagination_config.get_source_config("receipt"), None + ) + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user(user_id) + ) + + public_room_ids = yield self.store.get_public_room_ids() + + limit = pagin_config.limit + if limit is None: + limit = 10 + + @defer.inlineCallbacks + def handle_room(event): + d = { + "room_id": event.room_id, + "membership": event.membership, + "visibility": ( + "public" if event.room_id in public_room_ids + else "private" + ), + } + + if event.membership == Membership.INVITE: + time_now = self.clock.time_msec() + d["inviter"] = event.sender + + invite_event = yield self.store.get_event(event.event_id) + d["invite"] = serialize_event(invite_event, time_now, as_client_event) + + rooms_ret.append(d) + + if event.membership not in (Membership.JOIN, Membership.LEAVE): + return + + try: + if event.membership == Membership.JOIN: + room_end_token = now_token.room_key + deferred_room_state = self.state_handler.get_current_state( + event.room_id + ) + elif event.membership == Membership.LEAVE: + room_end_token = "s%d" % (event.stream_ordering,) + deferred_room_state = self.store.get_state_for_events( + [event.event_id], None + ) + deferred_room_state.addCallback( + lambda states: states[event.event_id] + ) + + (messages, token), current_state = yield preserve_context_over_deferred( + defer.gatherResults( + [ + preserve_fn(self.store.get_recent_events_for_room)( + event.room_id, + limit=limit, + end_token=room_end_token, + ), + deferred_room_state, + ] + ) + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + time_now = self.clock.time_msec() + + d["messages"] = { + "chunk": [ + serialize_event(m, time_now, as_client_event) + for m in messages + ], + "start": start_token.to_string(), + "end": end_token.to_string(), + } + + d["state"] = [ + serialize_event(c, time_now, as_client_event) + for c in current_state.values() + ] + + account_data_events = [] + tags = tags_by_room.get(event.room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = account_data_by_room.get(event.room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + d["account_data"] = account_data_events + except: + logger.exception("Failed to get snapshot") + + yield concurrently_execute(handle_room, room_list, 10) + + account_data_events = [] + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + ret = { + "rooms": rooms_ret, + "presence": presence, + "account_data": account_data_events, + "receipts": receipt, + "end": now_token.to_string(), + } + + defer.returnValue(ret) + + @defer.inlineCallbacks + def room_initial_sync(self, requester, room_id, pagin_config=None): + """Capture the a snapshot of a room. If user is currently a member of + the room this will be what is currently in the room. If the user left + the room this will be what was in the room when they left. + + Args: + requester(Requester): The user to get a snapshot for. + room_id(str): The room to get a snapshot of. + pagin_config(synapse.streams.config.PaginationConfig): + The pagination config used to determine how many messages to + return. + Raises: + AuthError if the user wasn't in the room. + Returns: + A JSON serialisable dict with the snapshot of the room. + """ + + user_id = requester.user.to_string() + + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id, + ) + is_peeking = member_event_id is None + + if membership == Membership.JOIN: + result = yield self._room_initial_sync_joined( + user_id, room_id, pagin_config, membership, is_peeking + ) + elif membership == Membership.LEAVE: + result = yield self._room_initial_sync_parted( + user_id, room_id, pagin_config, membership, member_event_id, is_peeking + ) + + account_data_events = [] + tags = yield self.store.get_tags_for_room(user_id, room_id) + if tags: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + result["account_data"] = account_data_events + + defer.returnValue(result) + + @defer.inlineCallbacks + def _room_initial_sync_parted(self, user_id, room_id, pagin_config, + membership, member_event_id, is_peeking): + room_state = yield self.store.get_state_for_events( + [member_event_id], None + ) + + room_state = room_state[member_event_id] + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + stream_token = yield self.store.get_stream_token_for_event( + member_event_id + ) + + messages, token = yield self.store.get_recent_events_for_room( + room_id, + limit=limit, + end_token=stream_token + ) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking + ) + + start_token = StreamToken.START.copy_and_replace("room_key", token[0]) + end_token = StreamToken.START.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + defer.returnValue({ + "membership": membership, + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": [serialize_event(s, time_now) for s in room_state.values()], + "presence": [], + "receipts": [], + }) + + @defer.inlineCallbacks + def _room_initial_sync_joined(self, user_id, room_id, pagin_config, + membership, is_peeking): + current_state = yield self.state.get_current_state( + room_id=room_id, + ) + + # TODO: These concurrently + time_now = self.clock.time_msec() + state = [ + serialize_event(x, time_now) + for x in current_state.values() + ] + + now_token = yield self.hs.get_event_sources().get_current_token() + + limit = pagin_config.limit if pagin_config else None + if limit is None: + limit = 10 + + room_members = [ + m for m in current_state.values() + if m.type == EventTypes.Member + and m.content["membership"] == Membership.JOIN + ] + + presence_handler = self.hs.get_presence_handler() + + @defer.inlineCallbacks + def get_presence(): + states = yield presence_handler.get_states( + [m.user_id for m in room_members], + as_event=True, + ) + + defer.returnValue(states) + + @defer.inlineCallbacks + def get_receipts(): + receipts = yield self.store.get_linearized_receipts_for_room( + room_id, + to_key=now_token.receipt_key, + ) + if not receipts: + receipts = [] + defer.returnValue(receipts) + + presence, receipts, (messages, token) = yield defer.gatherResults( + [ + preserve_fn(get_presence)(), + preserve_fn(get_receipts)(), + preserve_fn(self.store.get_recent_events_for_room)( + room_id, + limit=limit, + end_token=now_token.room_key, + ) + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + messages = yield filter_events_for_client( + self.store, user_id, messages, is_peeking=is_peeking, + ) + + start_token = now_token.copy_and_replace("room_key", token[0]) + end_token = now_token.copy_and_replace("room_key", token[1]) + + time_now = self.clock.time_msec() + + ret = { + "room_id": room_id, + "messages": { + "chunk": [serialize_event(m, time_now) for m in messages], + "start": start_token.to_string(), + "end": end_token.to_string(), + }, + "state": state, + "presence": presence, + "receipts": receipts, + } + if not is_peeking: + ret["membership"] = membership + + defer.returnValue(ret) + + @defer.inlineCallbacks + def _check_in_room_or_world_readable(self, room_id, user_id): + try: + # check_user_was_in_room will return the most recent membership + # event for the user if: + # * The user is a non-guest user, and was ever in the room + # * The user is a guest user, and has joined the room + # else it will throw. + member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + defer.returnValue((member_event.membership, member_event.event_id)) + return + except AuthError: + visibility = yield self.state_handler.get_current_state( + room_id, EventTypes.RoomHistoryVisibility, "" + ) + if ( + visibility and + visibility.content["history_visibility"] == "world_readable" + ): + defer.returnValue((Membership.JOIN, None)) + return + raise AuthError( + 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN + ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5c50c611ba..fd09397226 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,27 +16,29 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.streams.config import PaginationConfig +from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator -from synapse.util import unwrapFirstError -from synapse.util.caches.snapshot_cache import SnapshotCache -from synapse.types import UserID, RoomStreamToken, StreamToken +from synapse.push.action_generator import ActionGenerator +from synapse.types import ( + UserID, RoomAlias, RoomStreamToken, +) +from synapse.util.async import run_on_reactor, ReadWriteLock +from synapse.util.logcontext import preserve_fn +from synapse.util.metrics import measure_func +from synapse.visibility import filter_events_for_client from ._base import BaseHandler from canonicaljson import encode_canonical_json import logging +import random logger = logging.getLogger(__name__) -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class MessageHandler(BaseHandler): def __init__(self, hs): @@ -45,40 +47,24 @@ class MessageHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() - @defer.inlineCallbacks - def get_message(self, msg_id=None, room_id=None, sender_id=None, - user_id=None): - """ Retrieve a message. + self.pagination_lock = ReadWriteLock() - Args: - msg_id (str): The message ID to obtain. - room_id (str): The room where the message resides. - sender_id (str): The user ID of the user who sent the message. - user_id (str): The user ID of the user making this request. - Returns: - The message, or None if no message exists. - Raises: - SynapseError if something went wrong. - """ - yield self.auth.check_joined_room(room_id, user_id) + @defer.inlineCallbacks + def purge_history(self, room_id, event_id): + event = yield self.store.get_event(event_id) - # Pull out the message from the db -# msg = yield self.store.get_message( -# room_id=room_id, -# msg_id=msg_id, -# user_id=sender_id -# ) + if event.room_id != room_id: + raise SynapseError(400, "Event is for wrong room.") - # TODO (erikj): Once we work out the correct c-s api we need to think - # on how to do this. + depth = event.depth - defer.returnValue(None) + with (yield self.pagination_lock.write(room_id)): + yield self.store.delete_old_state(room_id, depth) @defer.inlineCallbacks def get_messages(self, requester, room_id=None, pagin_config=None, - as_client_event=True): + as_client_event=True, event_filter=None): """Get messages in a room. Args: @@ -87,18 +73,18 @@ class MessageHandler(BaseHandler): pagin_config (synapse.api.streams.PaginationConfig): The pagination config rules to apply, if any. as_client_event (bool): True to get events in client-server format. + event_filter (Filter): Filter to apply to results or None Returns: dict: Pagination API results """ user_id = requester.user.to_string() - data_source = self.hs.get_event_sources().sources["room"] if pagin_config.from_token: room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token( - direction='b' + yield self.hs.get_event_sources().get_current_token_for_room( + room_id=room_id ) ) room_token = pagin_config.from_token.room_key @@ -111,42 +97,48 @@ class MessageHandler(BaseHandler): source_config = pagin_config.get_source_config("room") - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id - ) + with (yield self.pagination_lock.read(room_id)): + membership, member_event_id = yield self._check_in_room_or_world_readable( + room_id, user_id + ) - if source_config.direction == 'b': - # if we're going backwards, we might need to backfill. This - # requires that we have a topo token. - if room_token.topological: - max_topo = room_token.topological - else: - max_topo = yield self.store.get_max_topological_token_for_stream_and_room( - room_id, room_token.stream - ) + if source_config.direction == 'b': + # if we're going backwards, we might need to backfill. This + # requires that we have a topo token. + if room_token.topological: + max_topo = room_token.topological + else: + max_topo = yield self.store.get_max_topological_token( + room_id, room_token.stream + ) - if membership == Membership.LEAVE: - # If they have left the room then clamp the token to be before - # they left the room, to save the effort of loading from the - # database. - leave_token = yield self.store.get_topological_token_for_event( - member_event_id + if membership == Membership.LEAVE: + # If they have left the room then clamp the token to be before + # they left the room, to save the effort of loading from the + # database. + leave_token = yield self.store.get_topological_token_for_event( + member_event_id + ) + leave_token = RoomStreamToken.parse(leave_token) + if leave_token.topological < max_topo: + source_config.from_key = str(leave_token) + + yield self.hs.get_handlers().federation_handler.maybe_backfill( + room_id, max_topo ) - leave_token = RoomStreamToken.parse(leave_token) - if leave_token.topological < max_topo: - source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( - room_id, max_topo + events, next_key = yield self.store.paginate_room_events( + room_id=room_id, + from_key=source_config.from_key, + to_key=source_config.to_key, + direction=source_config.direction, + limit=source_config.limit, + event_filter=event_filter, ) - events, next_key = yield data_source.get_pagination_rows( - requester.user, source_config, room_id - ) - - next_token = pagin_config.from_token.copy_and_replace( - "room_key", next_key - ) + next_token = pagin_config.from_token.copy_and_replace( + "room_key", next_key + ) if not events: defer.returnValue({ @@ -155,7 +147,11 @@ class MessageHandler(BaseHandler): "end": next_token.to_string(), }) - events = yield self._filter_events_for_client( + if event_filter: + events = event_filter.filter(events) + + events = yield filter_events_for_client( + self.store, user_id, events, is_peeking=(member_event_id is None), @@ -175,7 +171,7 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_event(self, event_dict, token_id=None, txn_id=None): + def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None): """ Given a dict from a client, create a new event. @@ -186,6 +182,9 @@ class MessageHandler(BaseHandler): Args: event_dict (dict): An entire event + token_id (str) + txn_id (str) + prev_event_ids (list): The prev event ids to use when creating the event Returns: Tuple of created event (FrozenEvent), Context @@ -198,12 +197,8 @@ class MessageHandler(BaseHandler): membership = builder.content.get("membership", None) target = UserID.from_string(builder.state_key) - if membership == Membership.JOIN: + if membership in {Membership.JOIN, Membership.INVITE}: # If event doesn't include a display name, add one. - yield collect_presencelike_data( - self.distributor, target, builder.content - ) - elif membership == Membership.INVITE: profile = self.hs.get_handlers().profile_handler content = builder.content @@ -224,6 +219,7 @@ class MessageHandler(BaseHandler): event, context = yield self._create_new_client_event( builder=builder, + prev_event_ids=prev_event_ids, ) defer.returnValue((event, context)) @@ -244,12 +240,27 @@ class MessageHandler(BaseHandler): "Tried to send member event through non-member codepath" ) + # We check here if we are currently being rate limited, so that we + # don't do unnecessary work. We check again just before we actually + # send the event. + time_now = self.clock.time() + allowed, time_allowed = self.ratelimiter.send_message( + event.sender, time_now, + msg_rate_hz=self.hs.config.rc_messages_per_second, + burst_count=self.hs.config.rc_message_burst_count, + update=False, + ) + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,) if event.is_state(): - prev_state = self.deduplicate_state_event(event, context) + prev_state = yield self.deduplicate_state_event(event, context) if prev_state is not None: defer.returnValue(prev_state) @@ -261,9 +272,10 @@ class MessageHandler(BaseHandler): ) if event.type == EventTypes.Message: - presence = self.hs.get_handlers().presence_handler + presence = self.hs.get_presence_handler() yield presence.bump_presence_active_time(user) + @defer.inlineCallbacks def deduplicate_state_event(self, event, context): """ Checks whether event is in the latest resolved state in context. @@ -271,13 +283,17 @@ class MessageHandler(BaseHandler): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_event = context.current_state.get((event.type, event.state_key)) + prev_event_id = context.prev_state_ids.get((event.type, event.state_key)) + prev_event = yield self.store.get_event(prev_event_id, allow_none=True) + if not prev_event: + return + if prev_event and event.user_id == prev_event.user_id: prev_content = encode_canonical_json(prev_event.content) next_content = encode_canonical_json(event.content) if prev_content == next_content: - return prev_event - return None + defer.returnValue(prev_event) + return @defer.inlineCallbacks def create_and_send_nonmember_event( @@ -388,378 +404,210 @@ class MessageHandler(BaseHandler): [serialize_event(c, now) for c in room_state.values()] ) - def snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - """Retrieve a snapshot of all rooms the user is invited or has joined. - - This snapshot may include messages for all rooms where the user is - joined, depending on the pagination config. - - Args: - user_id (str): The ID of the user making the request. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config used to determine how many messages *PER ROOM* to return. - as_client_event (bool): True to get events in client-server format. - include_archived (bool): True to get rooms that the user has left - Returns: - A list of dicts with "room_id" and "membership" keys for all rooms - the user is currently invited or joined in on. Rooms where the user - is joined on, may return a "messages" key with messages, depending - on the specified PaginationConfig. - """ - key = ( - user_id, - pagin_config.from_token, - pagin_config.to_token, - pagin_config.direction, - pagin_config.limit, - as_client_event, - include_archived, - ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - - return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - )) - + @measure_func("_create_new_client_event") @defer.inlineCallbacks - def _snapshot_all_rooms(self, user_id=None, pagin_config=None, - as_client_event=True, include_archived=False): - - memberships = [Membership.INVITE, Membership.JOIN] - if include_archived: - memberships.append(Membership.LEAVE) + def _create_new_client_event(self, builder, prev_event_ids=None): + if prev_event_ids: + prev_events = yield self.store.add_event_hashes(prev_event_ids) + prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) + depth = prev_max_depth + 1 + else: + latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room( + builder.room_id, + ) - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=user_id, membership_list=memberships - ) + # We want to limit the max number of prev events we point to in our + # new event + if len(latest_ret) > 10: + # Sort by reverse depth, so we point to the most recent. + latest_ret.sort(key=lambda a: -a[2]) + new_latest_ret = latest_ret[:5] + + # We also randomly point to some of the older events, to make + # sure that we don't completely ignore the older events. + if latest_ret[5:]: + sample_size = min(5, len(latest_ret[5:])) + new_latest_ret.extend(random.sample(latest_ret[5:], sample_size)) + latest_ret = new_latest_ret + + if latest_ret: + depth = max([d for _, _, d in latest_ret]) + 1 + else: + depth = 1 - user = UserID.from_string(user_id) + prev_events = [ + (event_id, prev_hashes) + for event_id, prev_hashes, _ in latest_ret + ] - rooms_ret = [] + builder.prev_events = prev_events + builder.depth = depth - now_token = yield self.hs.get_event_sources().get_current_token() + state_handler = self.state_handler - presence_stream = self.hs.get_event_sources().sources["presence"] - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user, pagination_config.get_source_config("presence"), None - ) + context = yield state_handler.compute_event_context(builder) - receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( - user, pagination_config.get_source_config("receipt"), None - ) + if builder.is_state(): + builder.prev_state = yield self.store.add_event_hashes( + context.prev_state_events + ) - tags_by_room = yield self.store.get_tags_for_user(user_id) + yield self.auth.add_auth_events(builder, context) - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user(user_id) + signing_key = self.hs.config.signing_key[0] + add_hashes_and_signatures( + builder, self.server_name, signing_key ) - public_room_ids = yield self.store.get_public_room_ids() + event = builder.build() - limit = pagin_config.limit - if limit is None: - limit = 10 + logger.debug( + "Created event %s with state: %s", + event.event_id, context.prev_state_ids, + ) - @defer.inlineCallbacks - def handle_room(event): - d = { - "room_id": event.room_id, - "membership": event.membership, - "visibility": ( - "public" if event.room_id in public_room_ids - else "private" - ), - } + defer.returnValue( + (event, context,) + ) - if event.membership == Membership.INVITE: - time_now = self.clock.time_msec() - d["inviter"] = event.sender + @measure_func("handle_new_client_event") + @defer.inlineCallbacks + def handle_new_client_event( + self, + requester, + event, + context, + ratelimit=True, + extra_users=[] + ): + # We now need to go and hit out to wherever we need to hit out to. - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = serialize_event(invite_event, time_now, as_client_event) + if ratelimit: + self.ratelimit(requester) - rooms_ret.append(d) + try: + yield self.auth.check_from_context(event, context) + except AuthError as err: + logger.warn("Denying new event %r because %s", event, err) + raise err + + yield self.maybe_kick_guest_users(event, context) + + if event.type == EventTypes.CanonicalAlias: + # Check the alias is acually valid (at this time at least) + room_alias_str = event.content.get("alias", None) + if room_alias_str: + room_alias = RoomAlias.from_string(room_alias_str) + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) + + if mapping["room_id"] != event.room_id: + raise SynapseError( + 400, + "Room alias %s does not point to the room" % ( + room_alias_str, + ) + ) - if event.membership not in (Membership.JOIN, Membership.LEAVE): - return + federation_handler = self.hs.get_handlers().federation_handler - try: - if event.membership == Membership.JOIN: - room_end_token = now_token.room_key - deferred_room_state = self.state_handler.get_current_state( - event.room_id - ) - elif event.membership == Membership.LEAVE: - room_end_token = "s%d" % (event.stream_ordering,) - deferred_room_state = self.store.get_state_for_events( - [event.event_id], None - ) - deferred_room_state.addCallback( - lambda states: states[event.event_id] + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.INVITE: + def is_inviter_member_event(e): + return ( + e.type == EventTypes.Member and + e.sender == event.sender ) - (messages, token), current_state = yield defer.gatherResults( - [ - self.store.get_recent_events_for_room( - event.room_id, - limit=limit, - end_token=room_end_token, - ), - deferred_room_state, - ] - ).addErrback(unwrapFirstError) - - messages = yield self._filter_events_for_client( - user_id, messages - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - time_now = self.clock.time_msec() - - d["messages"] = { - "chunk": [ - serialize_event(m, time_now, as_client_event) - for m in messages - ], - "start": start_token.to_string(), - "end": end_token.to_string(), - } - - d["state"] = [ - serialize_event(c, time_now, as_client_event) - for c in current_state.values() + state_to_include_ids = [ + e_id + for k, e_id in context.current_state_ids.items() + if k[0] in self.hs.config.room_invite_state_types + or k[0] == EventTypes.Member and k[1] == event.sender ] - account_data_events = [] - tags = tags_by_room.get(event.room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(event.room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - d["account_data"] = account_data_events - except: - logger.exception("Failed to get snapshot") - - # Only do N rooms at once - n = 5 - d_list = [handle_room(e) for e in room_list] - for i in range(0, len(d_list), n): - yield defer.gatherResults( - d_list[i:i + n], - consumeErrors=True - ).addErrback(unwrapFirstError) - - account_data_events = [] - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - ret = { - "rooms": rooms_ret, - "presence": presence, - "account_data": account_data_events, - "receipts": receipt, - "end": now_token.to_string(), - } + state_to_include = yield self.store.get_events(state_to_include_ids) - defer.returnValue(ret) + event.unsigned["invite_room_state"] = [ + { + "type": e.type, + "state_key": e.state_key, + "content": e.content, + "sender": e.sender, + } + for e in state_to_include.values() + ] - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): - """Capture the a snapshot of a room. If user is currently a member of - the room this will be what is currently in the room. If the user left - the room this will be what was in the room when they left. + invitee = UserID.from_string(event.state_key) + if not self.hs.is_mine(invitee): + # TODO: Can we add signature from remote server in a nicer + # way? If we have been invited by a remote server, we need + # to get them to sign the event. - Args: - requester(Requester): The user to get a snapshot for. - room_id(str): The room to get a snapshot of. - pagin_config(synapse.streams.config.PaginationConfig): - The pagination config used to determine how many messages to - return. - Raises: - AuthError if the user wasn't in the room. - Returns: - A JSON serialisable dict with the snapshot of the room. - """ + returned_invite = yield federation_handler.send_invite( + invitee.domain, + event, + ) - user_id = requester.user.to_string() + event.unsigned.pop("room_state", None) - membership, member_event_id = yield self._check_in_room_or_world_readable( - room_id, user_id, - ) - is_peeking = member_event_id is None + # TODO: Make sure the signatures actually are correct. + event.signatures.update( + returned_invite.signatures + ) - if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( - user_id, room_id, pagin_config, membership, is_peeking - ) - elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( - user_id, room_id, pagin_config, membership, member_event_id, is_peeking + if event.type == EventTypes.Redaction: + auth_events_ids = yield self.auth.compute_auth_events( + event, context.prev_state_ids, for_verification=True, ) + auth_events = yield self.store.get_events(auth_events_ids) + auth_events = { + (e.type, e.state_key): e for e in auth_events.values() + } + if self.auth.check_redaction(event, auth_events=auth_events): + original_event = yield self.store.get_event( + event.redacts, + check_redacted=False, + get_prev_content=False, + allow_rejected=False, + allow_none=False + ) + if event.user_id != original_event.user_id: + raise AuthError( + 403, + "You don't have permission to redact events" + ) - account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) - if tags: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = yield self.store.get_account_data_for_room(user_id, room_id) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - result["account_data"] = account_data_events - - defer.returnValue(result) - - @defer.inlineCallbacks - def _room_initial_sync_parted(self, user_id, room_id, pagin_config, - membership, member_event_id, is_peeking): - room_state = yield self.store.get_state_for_events( - [member_event_id], None - ) - - room_state = room_state[member_event_id] - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 + if event.type == EventTypes.Create and context.prev_state_ids: + raise AuthError( + 403, + "Changing the room create event is forbidden", + ) - stream_token = yield self.store.get_stream_token_for_event( - member_event_id + action_generator = ActionGenerator(self.hs) + yield action_generator.handle_push_actions_for_event( + event, context ) - messages, token = yield self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=stream_token + (event_stream_id, max_stream_id) = yield self.store.persist_event( + event, context=context ) - messages = yield self._filter_events_for_client( - user_id, messages, is_peeking=is_peeking + # this intentionally does not yield: we don't care about the result + # and don't need to wait for it. + preserve_fn(self.hs.get_pusherpool().on_new_notifications)( + event_stream_id, max_stream_id ) - start_token = StreamToken.START.copy_and_replace("room_key", token[0]) - end_token = StreamToken.START.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - defer.returnValue({ - "membership": membership, - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": [serialize_event(s, time_now) for s in room_state.values()], - "presence": [], - "receipts": [], - }) - - @defer.inlineCallbacks - def _room_initial_sync_joined(self, user_id, room_id, pagin_config, - membership, is_peeking): - current_state = yield self.state.get_current_state( - room_id=room_id, - ) - - # TODO: These concurrently - time_now = self.clock.time_msec() - state = [ - serialize_event(x, time_now) - for x in current_state.values() - ] - - now_token = yield self.hs.get_event_sources().get_current_token() - - limit = pagin_config.limit if pagin_config else None - if limit is None: - limit = 10 - - room_members = [ - m for m in current_state.values() - if m.type == EventTypes.Member - and m.content["membership"] == Membership.JOIN - ] - - presence_handler = self.hs.get_handlers().presence_handler - @defer.inlineCallbacks - def get_presence(): - states = yield presence_handler.get_states( - [m.user_id for m in room_members], - as_event=True, + def _notify(): + yield run_on_reactor() + yield self.notifier.on_new_room_event( + event, event_stream_id, max_stream_id, + extra_users=extra_users ) - defer.returnValue(states) - - @defer.inlineCallbacks - def get_receipts(): - receipts_handler = self.hs.get_handlers().receipts_handler - receipts = yield receipts_handler.get_receipts_for_room( - room_id, - now_token.receipt_key - ) - defer.returnValue(receipts) - - presence, receipts, (messages, token) = yield defer.gatherResults( - [ - get_presence(), - get_receipts(), - self.store.get_recent_events_for_room( - room_id, - limit=limit, - end_token=now_token.room_key, - ) - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - - messages = yield self._filter_events_for_client( - user_id, messages, is_peeking=is_peeking, - ) - - start_token = now_token.copy_and_replace("room_key", token[0]) - end_token = now_token.copy_and_replace("room_key", token[1]) - - time_now = self.clock.time_msec() - - ret = { - "room_id": room_id, - "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], - "start": start_token.to_string(), - "end": end_token.to_string(), - }, - "state": state, - "presence": presence, - "receipts": receipts, - } - if not is_peeking: - ret["membership"] = membership + preserve_fn(_notify)() - defer.returnValue(ret) + # If invite, remove room_state from unsigned before sending. + event.unsigned.pop("invite_room_state", None) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d0c8f1328b..1b89dc6274 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -33,11 +33,9 @@ from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id import synapse.metrics -from ._base import BaseHandler - import logging @@ -52,6 +50,13 @@ timers_fired_counter = metrics.register_counter("timers_fired") federation_presence_counter = metrics.register_counter("federation_presence") bump_active_time_counter = metrics.register_counter("bump_active_time") +get_updates_counter = metrics.register_counter("get_updates", labels=["type"]) + +notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"]) +state_transition_counter = metrics.register_counter( + "state_transition", labels=["from", "to"] +) + # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # "currently_active" @@ -70,38 +75,45 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000 # How often to resend presence to remote servers FEDERATION_PING_INTERVAL = 25 * 60 * 1000 +# How long we will wait before assuming that the syncs from an external process +# are dead. +EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000 + assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER -class PresenceHandler(BaseHandler): +class PresenceHandler(object): def __init__(self, hs): - super(PresenceHandler, self).__init__(hs) - self.hs = hs + self.is_mine = hs.is_mine + self.is_mine_id = hs.is_mine_id self.clock = hs.get_clock() self.store = hs.get_datastore() self.wheel_timer = WheelTimer() self.notifier = hs.get_notifier() - self.federation = hs.get_replication_layer() + self.replication = hs.get_replication_layer() + self.federation = hs.get_federation_sender() + + self.state = hs.get_state_handler() - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence", self.incoming_presence ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_invite", lambda origin, content: self.invite_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_accept", lambda origin, content: self.accept_presence( observed_user=UserID.from_string(content["observed_user"]), observer_user=UserID.from_string(content["observer_user"]), ) ) - self.federation.register_edu_handler( + self.replication.register_edu_handler( "m.presence_deny", lambda origin, content: self.deny_presence( observed_user=UserID.from_string(content["observed_user"]), @@ -138,7 +150,7 @@ class PresenceHandler(BaseHandler): obj=state.user_id, then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT, ) - if self.hs.is_mine_id(state.user_id): + if self.is_mine_id(state.user_id): self.wheel_timer.insert( now=now, obj=state.user_id, @@ -160,20 +172,38 @@ class PresenceHandler(BaseHandler): self.serial_to_user = {} self._next_serial = 1 - # Keeps track of the number of *ongoing* syncs. While this is non zero - # a user will never go offline. + # Keeps track of the number of *ongoing* syncs on this process. While + # this is non zero a user will never go offline. self.user_to_num_current_syncs = {} + # Keeps track of the number of *ongoing* syncs on other processes. + # While any sync is ongoing on another process the user will never + # go offline. + # Each process has a unique identifier and an update frequency. If + # no update is received from that process within the update period then + # we assume that all the sync requests on that process have stopped. + # Stored as a dict from process_id to set of user_id, and a dict of + # process_id to millisecond timestamp last updated. + self.external_process_to_current_syncs = {} + self.external_process_last_updated_ms = {} + # Start a LoopingCall in 30s that fires every 5s. # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. self.clock.call_later( - 0 * 1000, + 30, self.clock.looping_call, self._handle_timeouts, 5000, ) + self.clock.call_later( + 60, + self.clock.looping_call, + self._persist_unpersisted_changes, + 60 * 1000, + ) + metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer)) @defer.inlineCallbacks @@ -188,7 +218,7 @@ class PresenceHandler(BaseHandler): is some spurious presence changes that will self-correct. """ logger.info( - "Performing _on_shutdown. Persiting %d unpersisted changes", + "Performing _on_shutdown. Persisting %d unpersisted changes", len(self.user_to_current_state) ) @@ -200,6 +230,27 @@ class PresenceHandler(BaseHandler): logger.info("Finished _on_shutdown") @defer.inlineCallbacks + def _persist_unpersisted_changes(self): + """We periodically persist the unpersisted changes, as otherwise they + may stack up and slow down shutdown times. + """ + logger.info( + "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes", + len(self.unpersisted_users_changes) + ) + + unpersisted = self.unpersisted_users_changes + self.unpersisted_users_changes = set() + + if unpersisted: + yield self.store.update_presence([ + self.user_to_current_state[user_id] + for user_id in unpersisted + ]) + + logger.info("Finished _persist_unpersisted_changes") + + @defer.inlineCallbacks def _update_states(self, new_states): """Updates presence of users. Sets the appropriate timeouts. Pokes the notifier and federation if and only if the changed presence state @@ -215,6 +266,12 @@ class PresenceHandler(BaseHandler): to_notify = {} # Changes we want to notify everyone about to_federation_ping = {} # These need sending keep-alives + # Only bother handling the last presence change for each user + new_states_dict = {} + for new_state in new_states: + new_states_dict[new_state.user_id] = new_state + new_state = new_states_dict.values() + for new_state in new_states: user_id = new_state.user_id @@ -228,7 +285,7 @@ class PresenceHandler(BaseHandler): new_state, should_notify, should_ping = handle_update( prev_state, new_state, - is_mine=self.hs.is_mine_id(user_id), + is_mine=self.is_mine_id(user_id), wheel_timer=self.wheel_timer, now=now ) @@ -268,31 +325,48 @@ class PresenceHandler(BaseHandler): """Checks the presence of users that have timed out and updates as appropriate. """ + logger.info("Handling presence timeouts") now = self.clock.time_msec() - with Measure(self.clock, "presence_handle_timeouts"): - # Fetch the list of users that *may* have timed out. Things may have - # changed since the timeout was set, so we won't necessarily have to - # take any action. - users_to_check = self.wheel_timer.fetch(now) + try: + with Measure(self.clock, "presence_handle_timeouts"): + # Fetch the list of users that *may* have timed out. Things may have + # changed since the timeout was set, so we won't necessarily have to + # take any action. + users_to_check = set(self.wheel_timer.fetch(now)) + + # Check whether the lists of syncing processes from an external + # process have expired. + expired_process_ids = [ + process_id for process_id, last_update + in self.external_process_last_updated_ms.items() + if now - last_update > EXTERNAL_PROCESS_EXPIRY + ] + for process_id in expired_process_ids: + users_to_check.update( + self.external_process_last_updated_ms.pop(process_id, ()) + ) + self.external_process_last_update.pop(process_id) - states = [ - self.user_to_current_state.get( - user_id, UserPresenceState.default(user_id) - ) - for user_id in set(users_to_check) - ] + states = [ + self.user_to_current_state.get( + user_id, UserPresenceState.default(user_id) + ) + for user_id in users_to_check + ] - timers_fired_counter.inc_by(len(states)) + timers_fired_counter.inc_by(len(states)) - changes = handle_timeouts( - states, - is_mine_fn=self.hs.is_mine_id, - user_to_num_current_syncs=self.user_to_num_current_syncs, - now=now, - ) + changes = handle_timeouts( + states, + is_mine_fn=self.is_mine_id, + syncing_user_ids=self.get_currently_syncing_users(), + now=now, + ) - preserve_fn(self._update_states)(changes) + preserve_fn(self._update_states)(changes) + except: + logger.exception("Exception in _handle_timeouts loop") @defer.inlineCallbacks def bump_presence_active_time(self, user): @@ -365,6 +439,74 @@ class PresenceHandler(BaseHandler): defer.returnValue(_user_syncing()) + def get_currently_syncing_users(self): + """Get the set of user ids that are currently syncing on this HS. + Returns: + set(str): A set of user_id strings. + """ + syncing_user_ids = { + user_id for user_id, count in self.user_to_num_current_syncs.items() + if count + } + for user_ids in self.external_process_to_current_syncs.values(): + syncing_user_ids.update(user_ids) + return syncing_user_ids + + @defer.inlineCallbacks + def update_external_syncs(self, process_id, syncing_user_ids): + """Update the syncing users for an external process + + Args: + process_id(str): An identifier for the process the users are + syncing against. This allows synapse to process updates + as user start and stop syncing against a given process. + syncing_user_ids(set(str)): The set of user_ids that are + currently syncing on that server. + """ + + # Grab the previous list of user_ids that were syncing on that process + prev_syncing_user_ids = ( + self.external_process_to_current_syncs.get(process_id, set()) + ) + # Grab the current presence state for both the users that are syncing + # now and the users that were syncing before this update. + prev_states = yield self.current_state_for_users( + syncing_user_ids | prev_syncing_user_ids + ) + updates = [] + time_now_ms = self.clock.time_msec() + + # For each new user that is syncing check if we need to mark them as + # being online. + for new_user_id in syncing_user_ids - prev_syncing_user_ids: + prev_state = prev_states[new_user_id] + if prev_state.state == PresenceState.OFFLINE: + updates.append(prev_state.copy_and_replace( + state=PresenceState.ONLINE, + last_active_ts=time_now_ms, + last_user_sync_ts=time_now_ms, + )) + else: + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + # For each user that is still syncing or stopped syncing update the + # last sync time so that we will correctly apply the grace period when + # they stop syncing. + for old_user_id in prev_syncing_user_ids: + prev_state = prev_states[old_user_id] + updates.append(prev_state.copy_and_replace( + last_user_sync_ts=time_now_ms, + )) + + yield self._update_states(updates) + + # Update the last updated time for the process. We expire the entries + # if we don't receive an update in the given timeframe. + self.external_process_last_updated_ms[process_id] = self.clock.time_msec() + self.external_process_to_current_syncs[process_id] = syncing_user_ids + @defer.inlineCallbacks def current_state_for_user(self, user_id): """Get the current presence state for a user. @@ -403,7 +545,7 @@ class PresenceHandler(BaseHandler): defer.returnValue(states) @defer.inlineCallbacks - def _get_interested_parties(self, states): + def _get_interested_parties(self, states, calculate_remote_hosts=True): """Given a list of states return which entities (rooms, users, servers) are interested in the given states. @@ -426,21 +568,24 @@ class PresenceHandler(BaseHandler): users_to_states.setdefault(state.user_id, []).append(state) hosts_to_states = {} - for room_id, states in room_ids_to_states.items(): - local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) - if not local_states: - continue + if calculate_remote_hosts: + for room_id, states in room_ids_to_states.items(): + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) + if not local_states: + continue + + users = yield self.state.get_current_user_in_room(room_id) + hosts = set(get_domain_from_id(u) for u in users) - hosts = yield self.store.get_joined_hosts_for_room(room_id) - for host in hosts: - hosts_to_states.setdefault(host, []).extend(local_states) + for host in hosts: + hosts_to_states.setdefault(host, []).extend(local_states) for user_id, states in users_to_states.items(): - local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states) + local_states = filter(lambda s: self.is_mine_id(s.user_id), states) if not local_states: continue - host = UserID.from_string(user_id).domain + host = get_domain_from_id(user_id) hosts_to_states.setdefault(host, []).extend(local_states) # TODO: de-dup hosts_to_states, as a single host might have multiple @@ -465,24 +610,24 @@ class PresenceHandler(BaseHandler): self._push_to_remotes(hosts_to_states) + @defer.inlineCallbacks + def notify_for_states(self, state, stream_id): + parties = yield self._get_interested_parties([state]) + room_ids_to_states, users_to_states, hosts_to_states = parties + + self.notifier.on_new_event( + "presence_key", stream_id, rooms=room_ids_to_states.keys(), + users=[UserID.from_string(u) for u in users_to_states.keys()] + ) + def _push_to_remotes(self, hosts_to_states): """Sends state updates to remote servers. Args: hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]` """ - now = self.clock.time_msec() for host, states in hosts_to_states.items(): - self.federation.send_edu( - destination=host, - edu_type="m.presence", - content={ - "push": [ - _format_user_presence_state(state, now) - for state in states - ] - } - ) + self.federation.send_presence(host, states) @defer.inlineCallbacks def incoming_presence(self, origin, content): @@ -503,6 +648,13 @@ class PresenceHandler(BaseHandler): ) continue + if get_domain_from_id(user_id) != origin: + logger.info( + "Got presence update from %r with bad 'user_id': %r", + origin, user_id, + ) + continue + presence_state = push.get("presence", None) if not presence_state: logger.info( @@ -562,17 +714,17 @@ class PresenceHandler(BaseHandler): defer.returnValue([ { "type": "m.presence", - "content": _format_user_presence_state(state, now), + "content": format_user_presence_state(state, now), } for state in updates ]) else: defer.returnValue([ - _format_user_presence_state(state, now) for state in updates + format_user_presence_state(state, now) for state in updates ]) @defer.inlineCallbacks - def set_state(self, target_user, state): + def set_state(self, target_user, state, ignore_status_msg=False): """Set the presence state of the user. """ status_msg = state.get("status_msg", None) @@ -589,10 +741,13 @@ class PresenceHandler(BaseHandler): prev_state = yield self.current_state_for_user(user_id) new_fields = { - "state": presence, - "status_msg": status_msg if presence != PresenceState.OFFLINE else None + "state": presence } + if not ignore_status_msg: + msg = status_msg if presence != PresenceState.OFFLINE else None + new_fields["status_msg"] = msg + if presence == PresenceState.ONLINE: new_fields["last_active_ts"] = self.clock.time_msec() @@ -611,14 +766,14 @@ class PresenceHandler(BaseHandler): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. - if self.hs.is_mine(user): + user_ids = yield self.state.get_current_user_in_room(room_id) + if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) - hosts = yield self.store.get_joined_hosts_for_room(room_id) + hosts = set(get_domain_from_id(u) for u in user_ids) self._push_to_remotes({host: (state,) for host in hosts}) else: - user_ids = yield self.store.get_users_in_room(room_id) - user_ids = filter(self.hs.is_mine_id, user_ids) + user_ids = filter(self.is_mine_id, user_ids) states = yield self.current_state_for_users(user_ids) @@ -628,7 +783,7 @@ class PresenceHandler(BaseHandler): def get_presence_list(self, observer_user, accepted=None): """Returns the presence for all users in their presence list. """ - if not self.hs.is_mine(observer_user): + if not self.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") presence_list = yield self.store.get_presence_list( @@ -659,7 +814,7 @@ class PresenceHandler(BaseHandler): observer_user.localpart, observed_user.to_string() ) - if self.hs.is_mine(observed_user): + if self.is_mine(observed_user): yield self.invite_presence(observed_user, observer_user) else: yield self.federation.send_edu( @@ -675,11 +830,11 @@ class PresenceHandler(BaseHandler): def invite_presence(self, observed_user, observer_user): """Handles new presence invites. """ - if not self.hs.is_mine(observed_user): + if not self.is_mine(observed_user): raise SynapseError(400, "User is not hosted on this Home Server") # TODO: Don't auto accept - if self.hs.is_mine(observer_user): + if self.is_mine(observer_user): yield self.accept_presence(observed_user, observer_user) else: self.federation.send_edu( @@ -742,7 +897,7 @@ class PresenceHandler(BaseHandler): Returns: A Deferred. """ - if not self.hs.is_mine(observer_user): + if not self.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") yield self.store.del_presence_list( @@ -793,28 +948,38 @@ class PresenceHandler(BaseHandler): def should_notify(old_state, new_state): """Decides if a presence state change should be sent to interested parties. """ + if old_state == new_state: + return False + if old_state.status_msg != new_state.status_msg: + notify_reason_counter.inc("status_msg_change") return True - if old_state.state == PresenceState.ONLINE: - if new_state.state != PresenceState.ONLINE: - # Always notify for online -> anything - return True + if old_state.state != new_state.state: + notify_reason_counter.inc("state_change") + state_transition_counter.inc(old_state.state, new_state.state) + return True + if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: + notify_reason_counter.inc("current_active_change") return True - if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: - # Always notify for a transition where last active gets bumped. - return True + if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Only notify about last active bumps if we're not currently acive + if not new_state.currently_active: + notify_reason_counter.inc("last_active_change_online") + return True - if old_state.state != new_state.state: + elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: + # Always notify for a transition where last active gets bumped. + notify_reason_counter.inc("last_active_change_not_online") return True return False -def _format_user_presence_state(state, now): +def format_user_presence_state(state, now): """Convert UserPresenceState to a format that can be sent down to clients and to other servers. """ @@ -834,9 +999,14 @@ def _format_user_presence_state(state, now): class PresenceEventSource(object): def __init__(self, hs): - self.hs = hs + # We can't call get_presence_handler here because there's a cycle: + # + # Presence -> Notifier -> PresenceEventSource -> Presence + # + self.get_presence_handler = hs.get_presence_handler self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -860,7 +1030,7 @@ class PresenceEventSource(object): from_key = int(from_key) room_ids = room_ids or [] - presence = self.hs.get_handlers().presence_handler + presence = self.get_presence_handler() stream_change_cache = self.store.presence_stream_cache if not room_ids: @@ -877,13 +1047,13 @@ class PresenceEventSource(object): user_ids_changed = set() changed = None - if from_key and max_token - from_key < 100: - # For small deltas, its quicker to get all changes and then - # work out if we share a room or they're in our presence list + if from_key: changed = stream_change_cache.get_all_entities_changed(from_key) - # get_all_entities_changed can return None - if changed is not None: + if changed is not None and len(changed) < 500: + # For small deltas, its quicker to get all changes and then + # work out if we share a room or they're in our presence list + get_updates_counter.inc("stream") for other_user_id in changed: if other_user_id in friends: user_ids_changed.add(other_user_id) @@ -895,9 +1065,11 @@ class PresenceEventSource(object): else: # Too many possible updates. Find all users we can see and check # if any of them have changed. + get_updates_counter.inc("full") + user_ids_to_check = set() for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) user_ids_to_check.update(users) user_ids_to_check.update(friends) @@ -920,7 +1092,7 @@ class PresenceEventSource(object): defer.returnValue(([ { "type": "m.presence", - "content": _format_user_presence_state(s, now), + "content": format_user_presence_state(s, now), } for s in updates.values() if include_offline or s.state != PresenceState.OFFLINE @@ -933,15 +1105,14 @@ class PresenceEventSource(object): return self.get_new_events(user, from_key=None, include_offline=False) -def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): +def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now): """Checks the presence of users that have timed out and updates as appropriate. Args: user_states(list): List of UserPresenceState's to check. is_mine_fn (fn): Function that returns if a user_id is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -952,21 +1123,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): for state in user_states: is_mine = is_mine_fn(state.user_id) - new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) + new_state = handle_timeout(state, is_mine, syncing_user_ids, now) if new_state: changes[state.user_id] = new_state return changes.values() -def handle_timeout(state, is_mine, user_to_num_current_syncs, now): +def handle_timeout(state, is_mine, syncing_user_ids, now): """Checks the presence of the user to see if any of the timers have elapsed Args: state (UserPresenceState) is_mine (bool): Whether the user is ours - user_to_num_current_syncs (dict): Mapping of user_id to number of currently - active syncs. + syncing_user_ids (set): Set of user_ids with active syncs. now (int): Current time in ms. Returns: @@ -1000,7 +1170,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now): # If there are have been no sync for a while (and none ongoing), # set presence to offline - if not user_to_num_current_syncs.get(user_id, 0): + if user_id not in syncing_user_ids: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: state = state.copy_and_replace( state=PresenceState.OFFLINE, diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b45eafbb49..87f74dfb8e 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -13,28 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from twisted.internet import defer +import synapse.types from synapse.api.errors import SynapseError, AuthError, CodeMessageException -from synapse.types import UserID, Requester -from synapse.util import unwrapFirstError - +from synapse.types import UserID from ._base import BaseHandler -import logging - logger = logging.getLogger(__name__) -def changed_presencelike_data(distributor, user, state): - return distributor.fire("changed_presencelike_data", user, state) - - -def collect_presencelike_data(distributor, user, content): - return distributor.fire("collect_presencelike_data", user, content) - - class ProfileHandler(BaseHandler): def __init__(self, hs): @@ -45,21 +36,6 @@ class ProfileHandler(BaseHandler): "profile", self.on_profile_query ) - distributor = hs.get_distributor() - self.distributor = distributor - - distributor.declare("collect_presencelike_data") - distributor.declare("changed_presencelike_data") - - distributor.observe("registered_user", self.registered_user) - - distributor.observe( - "collect_presencelike_data", self.collect_presencelike_data - ) - - def registered_user(self, user): - return self.store.create_profile(user.localpart) - @defer.inlineCallbacks def get_displayname(self, target_user): if self.hs.is_mine(target_user): @@ -89,13 +65,13 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["displayname"]) @defer.inlineCallbacks - def set_displayname(self, target_user, requester, new_displayname): + def set_displayname(self, target_user, requester, new_displayname, by_admin=False): """target_user is the user whose displayname is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") if new_displayname == '': @@ -105,10 +81,6 @@ class ProfileHandler(BaseHandler): target_user.localpart, new_displayname ) - yield changed_presencelike_data(self.distributor, target_user, { - "displayname": new_displayname, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks @@ -139,44 +111,22 @@ class ProfileHandler(BaseHandler): defer.returnValue(result["avatar_url"]) @defer.inlineCallbacks - def set_avatar_url(self, target_user, requester, new_avatar_url): + def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): """target_user is the user whose avatar_url is to be changed; auth_user is the user attempting to make this change.""" if not self.hs.is_mine(target_user): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != requester.user: + if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") yield self.store.set_profile_avatar_url( target_user.localpart, new_avatar_url ) - yield changed_presencelike_data(self.distributor, target_user, { - "avatar_url": new_avatar_url, - }) - yield self._update_join_states(requester) @defer.inlineCallbacks - def collect_presencelike_data(self, user, state): - if not self.hs.is_mine(user): - defer.returnValue(None) - - (displayname, avatar_url) = yield defer.gatherResults( - [ - self.store.get_profile_displayname(user.localpart), - self.store.get_profile_avatar_url(user.localpart), - ], - consumeErrors=True - ).addErrback(unwrapFirstError) - - state["displayname"] = displayname - state["avatar_url"] = avatar_url - - defer.returnValue(None) - - @defer.inlineCallbacks def on_profile_query(self, args): user = UserID.from_string(args["user_id"]) if not self.hs.is_mine(user): @@ -215,7 +165,9 @@ class ProfileHandler(BaseHandler): try: # Assume the user isn't a guest because we don't let guests set # profile or avatar data. - requester = Requester(user, "", False) + # XXX why are we recreating `requester` here for each room? + # what was wrong with the `requester` we were passed? + requester = synapse.types.create_requester(user) yield handler.update_membership( requester, user, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 935c339707..916e80a48e 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from ._base import BaseHandler from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import get_domain_from_id import logging @@ -29,12 +30,15 @@ class ReceiptsHandler(BaseHandler): def __init__(self, hs): super(ReceiptsHandler, self).__init__(hs) + self.server_name = hs.config.server_name + self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_replication_layer() - self.federation.register_edu_handler( + self.federation = hs.get_federation_sender() + hs.get_replication_layer().register_edu_handler( "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() + self.state = hs.get_state_handler() @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, @@ -80,6 +84,9 @@ class ReceiptsHandler(BaseHandler): def _handle_new_receipts(self, receipts): """Takes a list of receipts, stores them and informs the notifier. """ + min_batch_id = None + max_batch_id = None + for receipt in receipts: room_id = receipt["room_id"] receipt_type = receipt["receipt_type"] @@ -97,10 +104,21 @@ class ReceiptsHandler(BaseHandler): stream_id, max_persisted_id = res - with PreserveLoggingContext(): - self.notifier.on_new_event( - "receipt_key", max_persisted_id, rooms=[room_id] - ) + if min_batch_id is None or stream_id < min_batch_id: + min_batch_id = stream_id + if max_batch_id is None or max_persisted_id > max_batch_id: + max_batch_id = max_persisted_id + + affected_room_ids = list(set([r["room_id"] for r in receipts])) + + with PreserveLoggingContext(): + self.notifier.on_new_event( + "receipt_key", max_batch_id, rooms=affected_room_ids + ) + # Note that the min here shouldn't be relied upon to be accurate. + self.hs.get_pusherpool().on_new_receipts( + min_batch_id, max_batch_id, affected_room_ids + ) defer.returnValue(True) @@ -117,12 +135,10 @@ class ReceiptsHandler(BaseHandler): event_ids = receipt["event_ids"] data = receipt["data"] - remotedomains = set() - - rm_handler = self.hs.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=None, remotedomains=remotedomains - ) + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) + remotedomains = remotedomains.copy() + remotedomains.discard(self.server_name) logger.debug("Sending receipt to: %r", remotedomains) @@ -140,6 +156,7 @@ class ReceiptsHandler(BaseHandler): } }, }, + key=(room_id, receipt_type, user_id), ) @defer.inlineCallbacks diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index f287ee247b..886fec8701 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,34 +14,28 @@ # limitations under the License. """Contains functions for registering clients.""" +import logging +import urllib + from twisted.internet import defer -from synapse.types import UserID from synapse.api.errors import ( AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError ) -from ._base import BaseHandler -from synapse.util.async import run_on_reactor from synapse.http.client import CaptchaServerHttpClient - -import logging -import urllib +from synapse.types import UserID +from synapse.util.async import run_on_reactor +from ._base import BaseHandler logger = logging.getLogger(__name__) -def registered_user(distributor, user): - return distributor.fire("registered_user", user) - - class RegistrationHandler(BaseHandler): def __init__(self, hs): super(RegistrationHandler, self).__init__(hs) self.auth = hs.get_auth() - self.distributor = hs.get_distributor() - self.distributor.declare("registered_user") self.captcha_client = CaptchaServerHttpClient(hs) self._next_generated_user_id = None @@ -58,6 +52,13 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME ) + if localpart[0] == '_': + raise SynapseError( + 400, + "User ID may not begin with _", + Codes.INVALID_USERNAME + ) + user = UserID(localpart, self.hs.hostname) user_id = user.to_string() @@ -96,7 +97,8 @@ class RegistrationHandler(BaseHandler): password=None, generate_token=True, guest_access_token=None, - make_guest=False + make_guest=False, + admin=False, ): """Registers a new client on the server. @@ -104,8 +106,13 @@ class RegistrationHandler(BaseHandler): localpart : The local part of the user ID to register. If None, one will be generated. password (str) : The password to assign to this user so they can - login again. This can be None which means they cannot login again - via a password (e.g. the user is an application service user). + login again. This can be None which means they cannot login again + via a password (e.g. the user is an application service user). + generate_token (bool): Whether a new access token should be + generated. Having this be True should be considered deprecated, + since it offers no means of associating a device_id with the + access_token. Instead you should call auth_handler.issue_access_token + after registration. Returns: A tuple of (user_id, access_token). Raises: @@ -143,9 +150,12 @@ class RegistrationHandler(BaseHandler): password_hash=password_hash, was_guest=was_guest, make_guest=make_guest, + create_profile_with_localpart=( + # If the user was a guest then they already have a profile + None if was_guest else user.localpart + ), + admin=admin, ) - - yield registered_user(self.distributor, user) else: # autogen a sequential user ID attempts = 0 @@ -163,7 +173,8 @@ class RegistrationHandler(BaseHandler): user_id=user_id, token=token, password_hash=password_hash, - make_guest=make_guest + make_guest=make_guest, + create_profile_with_localpart=user.localpart, ) except SynapseError: # if user id is taken, just generate another @@ -171,7 +182,6 @@ class RegistrationHandler(BaseHandler): user_id = None token = None attempts += 1 - yield registered_user(self.distributor, user) # We used to generate default identicons here, but nowadays # we want clients to generate their own as part of their branding @@ -183,7 +193,7 @@ class RegistrationHandler(BaseHandler): def appservice_register(self, user_localpart, as_token): user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() - service = yield self.store.get_app_service_by_token(as_token) + service = self.store.get_app_service_by_token(as_token) if not service: raise AuthError(403, "Invalid application service token.") if not service.is_interested_in_user(user_id): @@ -198,15 +208,13 @@ class RegistrationHandler(BaseHandler): user_id, allowed_appservice=service ) - token = self.auth_handler().generate_access_token(user_id) yield self.store.register( user_id=user_id, - token=token, password_hash="", appservice_id=service_id, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) - defer.returnValue((user_id, token)) + defer.returnValue(user_id) @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): @@ -251,9 +259,9 @@ class RegistrationHandler(BaseHandler): yield self.store.register( user_id=user_id, token=token, - password_hash=None + password_hash=None, + create_profile_with_localpart=user.localpart, ) - yield registered_user(self.distributor, user) except Exception as e: yield self.store.add_access_token_to_user(user_id, token) # Ignore Registration errors @@ -296,11 +304,10 @@ class RegistrationHandler(BaseHandler): # XXX: This should be a deferred list, shouldn't it? yield identity_handler.bind_threepid(c, user_id) - @defer.inlineCallbacks def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None): # valid user IDs must not clash with any user ID namespaces claimed by # application services. - services = yield self.store.get_app_services() + services = self.store.get_app_services() interested_services = [ s for s in services if s.is_interested_in_user(user_id) @@ -361,8 +368,61 @@ class RegistrationHandler(BaseHandler): ) defer.returnValue(data) + @defer.inlineCallbacks + def get_or_create_user(self, requester, localpart, displayname, + password_hash=None): + """Creates a new user if the user does not exist, + else revokes all previous access tokens and generates a new one. + + Args: + localpart : The local part of the user ID to register. If None, + one will be randomly generated. + Returns: + A tuple of (user_id, access_token). + Raises: + RegistrationError if there was a problem registering. + """ + yield run_on_reactor() + + if localpart is None: + raise SynapseError(400, "Request must include user id") + + need_register = True + + try: + yield self.check_username(localpart) + except SynapseError as e: + if e.errcode == Codes.USER_IN_USE: + need_register = False + else: + raise + + user = UserID(localpart, self.hs.hostname) + user_id = user.to_string() + token = self.auth_handler().generate_access_token(user_id) + + if need_register: + yield self.store.register( + user_id=user_id, + token=token, + password_hash=password_hash, + create_profile_with_localpart=user.localpart, + ) + else: + yield self.store.user_delete_access_tokens(user_id=user_id) + yield self.store.add_access_token_to_user(user_id=user_id, token=token) + + if displayname is not None: + logger.info("setting user display name: %s -> %s", user_id, displayname) + profile_handler = self.hs.get_handlers().profile_handler + yield profile_handler.set_displayname( + user, requester, displayname, by_admin=True, + ) + + defer.returnValue((user_id, token)) + def auth_handler(self): - return self.hs.get_handlers().auth_handler + return self.hs.get_auth_handler() @defer.inlineCallbacks def guest_access_token_for(self, medium, address, inviter_user_id): diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f7163470a9..5f18007e90 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -18,19 +18,15 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester +from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.api.constants import ( - EventTypes, Membership, JoinRules, RoomCreationPreset, + EventTypes, JoinRules, RoomCreationPreset ) -from synapse.api.errors import AuthError, StoreError, SynapseError, Codes -from synapse.util import stringutils, unwrapFirstError -from synapse.util.logcontext import preserve_context_over_fn - -from signedjson.sign import verify_signed_json -from signedjson.key import decode_verify_key_bytes +from synapse.api.errors import AuthError, StoreError, SynapseError +from synapse.util import stringutils +from synapse.visibility import filter_events_for_client from collections import OrderedDict -from unpaddedbase64 import decode_base64 import logging import math @@ -41,20 +37,6 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -def user_left_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_left_room", user=user, room_id=room_id - ) - - -def user_joined_room(distributor, user, room_id): - return preserve_context_over_fn( - distributor.fire, - "user_joined_room", user=user, room_id=room_id - ) - - class RoomCreationHandler(BaseHandler): PRESETS_DICT = { @@ -122,7 +104,8 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = config.get("invite_3pid", []) - is_public = config.get("visibility", None) == "public" + visibility = config.get("visibility", None) + is_public = visibility == "public" # autogen room IDs and try to create it. We may clash, so just # try a few times till one goes through, giving up eventually. @@ -158,9 +141,9 @@ class RoomCreationHandler(BaseHandler): preset_config = config.get( "preset", - RoomCreationPreset.PUBLIC_CHAT - if is_public - else RoomCreationPreset.PRIVATE_CHAT + RoomCreationPreset.PRIVATE_CHAT + if visibility == "private" + else RoomCreationPreset.PUBLIC_CHAT ) raw_initial_state = config.get("initial_state", []) @@ -212,6 +195,11 @@ class RoomCreationHandler(BaseHandler): }, ratelimit=False) + content = {} + is_direct = config.get("is_direct", None) + if is_direct: + content["is_direct"] = is_direct + for invitee in invite_list: yield room_member_handler.update_membership( requester, @@ -219,6 +207,7 @@ class RoomCreationHandler(BaseHandler): room_id, "invite", ratelimit=False, + content=content, ) for invite_3pid in invite_3pid_list: @@ -365,659 +354,6 @@ class RoomCreationHandler(BaseHandler): ) -class RoomMemberHandler(BaseHandler): - # TODO(paul): This handler currently contains a messy conflation of - # low-level API that works on UserID objects and so on, and REST-level - # API that takes ID strings and returns pagination chunks. These concerns - # ought to be separated out a lot better. - - def __init__(self, hs): - super(RoomMemberHandler, self).__init__(hs) - - self.clock = hs.get_clock() - - self.distributor = hs.get_distributor() - self.distributor.declare("user_joined_room") - self.distributor.declare("user_left_room") - - @defer.inlineCallbacks - def get_room_members(self, room_id): - users = yield self.store.get_users_in_room(room_id) - - defer.returnValue([UserID.from_string(u) for u in users]) - - @defer.inlineCallbacks - def fetch_room_distributions_into(self, room_id, localusers=None, - remotedomains=None, ignore_user=None): - """Fetch the distribution of a room, adding elements to either - 'localusers' or 'remotedomains', which should be a set() if supplied. - If ignore_user is set, ignore that user. - - This function returns nothing; its result is performed by the - side-effect on the two passed sets. This allows easy accumulation of - member lists of multiple rooms at once if required. - """ - members = yield self.get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if self.hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - - @defer.inlineCallbacks - def update_membership( - self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - ): - effective_membership_state = action - if action in ["kick", "unban"]: - effective_membership_state = "leave" - elif action == "forget": - effective_membership_state = "leave" - - if third_party_signed is not None: - replication = self.hs.get_replication_layer() - yield replication.exchange_third_party_invite( - third_party_signed["sender"], - target.to_string(), - room_id, - third_party_signed, - ) - - msg_handler = self.hs.get_handlers().message_handler - - content = {"membership": effective_membership_state} - if requester.is_guest: - content["kind"] = "guest" - - event, context = yield msg_handler.create_event( - { - "type": EventTypes.Member, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "state_key": target.to_string(), - - # For backwards compatibility: - "membership": effective_membership_state, - }, - token_id=requester.access_token_id, - txn_id=txn_id, - ) - - old_state = context.current_state.get((EventTypes.Member, event.state_key)) - old_membership = old_state.content.get("membership") if old_state else None - if action == "unban" and old_membership != "ban": - raise SynapseError( - 403, - "Cannot unban user who was not banned (membership=%s)" % old_membership, - errcode=Codes.BAD_STATE - ) - if old_membership == "ban" and action != "unban": - raise SynapseError( - 403, - "Cannot %s user who was is banned" % (action,), - errcode=Codes.BAD_STATE - ) - - member_handler = self.hs.get_handlers().room_member_handler - yield member_handler.send_membership_event( - requester, - event, - context, - ratelimit=ratelimit, - remote_room_hosts=remote_room_hosts, - ) - - if action == "forget": - yield self.forget(requester.user, room_id) - - @defer.inlineCallbacks - def send_membership_event( - self, - requester, - event, - context, - remote_room_hosts=None, - ratelimit=True, - ): - """ - Change the membership status of a user in a room. - - Args: - requester (Requester): The local user who requested the membership - event. If None, certain checks, like whether this homeserver can - act as the sender, will be skipped. - event (SynapseEvent): The membership event. - context: The context of the event. - is_guest (bool): Whether the sender is a guest. - room_hosts ([str]): Homeservers which are likely to already be in - the room, and could be danced with in order to join this - homeserver for the first time. - ratelimit (bool): Whether to rate limit this request. - Raises: - SynapseError if there was a problem changing the membership. - """ - remote_room_hosts = remote_room_hosts or [] - - target_user = UserID.from_string(event.state_key) - room_id = event.room_id - - if requester is not None: - sender = UserID.from_string(event.sender) - assert sender == requester.user, ( - "Sender (%s) must be same as requester (%s)" % - (sender, requester.user) - ) - assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) - else: - requester = Requester(target_user, None, False) - - message_handler = self.hs.get_handlers().message_handler - prev_event = message_handler.deduplicate_state_event(event, context) - if prev_event is not None: - return - - action = "send" - - if event.membership == Membership.JOIN: - if requester.is_guest and not self._can_guest_join(context.current_state): - # This should be an auth check, but guests are a local concept, - # so don't really fit into the general auth process. - raise AuthError(403, "Guest access not allowed") - do_remote_join_dance, remote_room_hosts = self._should_do_dance( - context, - (self.get_inviter(event.state_key, context.current_state)), - remote_room_hosts, - ) - if do_remote_join_dance: - action = "remote_join" - elif event.membership == Membership.LEAVE: - is_host_in_room = self.is_host_in_room(context.current_state) - - if not is_host_in_room: - # perhaps we've been invited - inviter = self.get_inviter(target_user.to_string(), context.current_state) - if not inviter: - raise SynapseError(404, "Not a known room") - - if self.hs.is_mine(inviter): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: - # send the rejection to the inviter's HS. - remote_room_hosts = remote_room_hosts + [inviter.domain] - action = "remote_reject" - - federation_handler = self.hs.get_handlers().federation_handler - - if action == "remote_join": - if len(remote_room_hosts) == 0: - raise SynapseError(404, "No known servers") - - # We don't do an auth check if we are doing an invite - # join dance for now, since we're kinda implicitly checking - # that we are allowed to join when we decide whether or not we - # need to do the invite/join dance. - yield federation_handler.do_invite_join( - remote_room_hosts, - event.room_id, - event.user_id, - event.content, - ) - elif action == "remote_reject": - yield federation_handler.do_remotely_reject_invite( - remote_room_hosts, - room_id, - event.user_id - ) - else: - yield self.handle_new_client_event( - requester, - event, - context, - extra_users=[target_user], - ratelimit=ratelimit, - ) - - prev_member_event = context.current_state.get( - (EventTypes.Member, target_user.to_string()), - None - ) - - if event.membership == Membership.JOIN: - if not prev_member_event or prev_member_event.membership != Membership.JOIN: - # Only fire user_joined_room if the user has acutally joined the - # room. Don't bother if the user is just changing their profile - # info. - yield user_joined_room(self.distributor, target_user, room_id) - elif event.membership == Membership.LEAVE: - if prev_member_event and prev_member_event.membership == Membership.JOIN: - user_left_room(self.distributor, target_user, room_id) - - def _can_guest_join(self, current_state): - """ - Returns whether a guest can join a room based on its current state. - """ - guest_access = current_state.get((EventTypes.GuestAccess, ""), None) - return ( - guest_access - and guest_access.content - and "guest_access" in guest_access.content - and guest_access.content["guest_access"] == "can_join" - ) - - def _should_do_dance(self, context, inviter, room_hosts=None): - # TODO: Shouldn't this be remote_room_host? - room_hosts = room_hosts or [] - - is_host_in_room = self.is_host_in_room(context.current_state) - if is_host_in_room: - return False, room_hosts - - if inviter and not self.hs.is_mine(inviter): - room_hosts.append(inviter.domain) - - return True, room_hosts - - @defer.inlineCallbacks - def lookup_room_alias(self, room_alias): - """ - Get the room ID associated with a room alias. - - Args: - room_alias (RoomAlias): The alias to look up. - Returns: - A tuple of: - The room ID as a RoomID object. - Hosts likely to be participating in the room ([str]). - Raises: - SynapseError if room alias could not be found. - """ - directory_handler = self.hs.get_handlers().directory_handler - mapping = yield directory_handler.get_association(room_alias) - - if not mapping: - raise SynapseError(404, "No such room alias") - - room_id = mapping["room_id"] - servers = mapping["servers"] - - defer.returnValue((RoomID.from_string(room_id), servers)) - - def get_inviter(self, user_id, current_state): - prev_state = current_state.get((EventTypes.Member, user_id)) - if prev_state and prev_state.membership == Membership.INVITE: - return UserID.from_string(prev_state.user_id) - return None - - @defer.inlineCallbacks - def get_joined_rooms_for_user(self, user): - """Returns a list of roomids that the user has any of the given - membership states in.""" - - rooms = yield self.store.get_rooms_for_user( - user.to_string(), - ) - - # For some reason the list of events contains duplicates - # TODO(paul): work out why because I really don't think it should - room_ids = set(r.room_id for r in rooms) - - defer.returnValue(room_ids) - - @defer.inlineCallbacks - def do_3pid_invite( - self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id - ): - invitee = yield self._lookup_3pid( - id_server, medium, address - ) - - if invitee: - handler = self.hs.get_handlers().room_member_handler - yield handler.update_membership( - requester, - UserID.from_string(invitee), - room_id, - "invite", - txn_id=txn_id, - ) - else: - yield self._make_and_store_3pid_invite( - requester, - id_server, - medium, - address, - room_id, - inviter, - txn_id=txn_id - ) - - @defer.inlineCallbacks - def _lookup_3pid(self, id_server, medium, address): - """Looks up a 3pid in the passed identity server. - - Args: - id_server (str): The server name (including port, if required) - of the identity server to use. - medium (str): The type of the third party identifier (e.g. "email"). - address (str): The third party identifier (e.g. "foo@example.com"). - - Returns: - (str) the matrix ID of the 3pid, or None if it is not recognized. - """ - try: - data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), - { - "medium": medium, - "address": address, - } - ) - - if "mxid" in data: - if "signatures" not in data: - raise AuthError(401, "No signatures on 3pid binding") - self.verify_any_signature(data, id_server) - defer.returnValue(data["mxid"]) - - except IOError as e: - logger.warn("Error from identity server lookup: %s" % (e,)) - defer.returnValue(None) - - @defer.inlineCallbacks - def verify_any_signature(self, data, server_hostname): - if server_hostname not in data["signatures"]: - raise AuthError(401, "No signature from server %s" % (server_hostname,)) - for key_name, signature in data["signatures"][server_hostname].items(): - key_data = yield self.hs.get_simple_http_client().get_json( - "%s%s/_matrix/identity/api/v1/pubkey/%s" % - (id_server_scheme, server_hostname, key_name,), - ) - if "public_key" not in key_data: - raise AuthError(401, "No public key named %s from %s" % - (key_name, server_hostname,)) - verify_signed_json( - data, - server_hostname, - decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) - ) - return - - @defer.inlineCallbacks - def _make_and_store_3pid_invite( - self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id - ): - room_state = yield self.hs.get_state_handler().get_current_state(room_id) - - inviter_display_name = "" - inviter_avatar_url = "" - member_event = room_state.get((EventTypes.Member, user.to_string())) - if member_event: - inviter_display_name = member_event.content.get("displayname", "") - inviter_avatar_url = member_event.content.get("avatar_url", "") - - canonical_room_alias = "" - canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) - if canonical_alias_event: - canonical_room_alias = canonical_alias_event.content.get("alias", "") - - room_name = "" - room_name_event = room_state.get((EventTypes.Name, "")) - if room_name_event: - room_name = room_name_event.content.get("name", "") - - room_join_rules = "" - join_rules_event = room_state.get((EventTypes.JoinRules, "")) - if join_rules_event: - room_join_rules = join_rules_event.content.get("join_rule", "") - - room_avatar_url = "" - room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) - if room_avatar_event: - room_avatar_url = room_avatar_event.content.get("url", "") - - token, public_keys, fallback_public_key, display_name = ( - yield self._ask_id_server_for_third_party_invite( - id_server=id_server, - medium=medium, - address=address, - room_id=room_id, - inviter_user_id=user.to_string(), - room_alias=canonical_room_alias, - room_avatar_url=room_avatar_url, - room_join_rules=room_join_rules, - room_name=room_name, - inviter_display_name=inviter_display_name, - inviter_avatar_url=inviter_avatar_url - ) - ) - - msg_handler = self.hs.get_handlers().message_handler - yield msg_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.ThirdPartyInvite, - "content": { - "display_name": display_name, - "public_keys": public_keys, - - # For backwards compatibility: - "key_validity_url": fallback_public_key["key_validity_url"], - "public_key": fallback_public_key["public_key"], - }, - "room_id": room_id, - "sender": user.to_string(), - "state_key": token, - }, - txn_id=txn_id, - ) - - @defer.inlineCallbacks - def _ask_id_server_for_third_party_invite( - self, - id_server, - medium, - address, - room_id, - inviter_user_id, - room_alias, - room_avatar_url, - room_join_rules, - room_name, - inviter_display_name, - inviter_avatar_url - ): - """ - Asks an identity server for a third party invite. - - :param id_server (str): hostname + optional port for the identity server. - :param medium (str): The literal string "email". - :param address (str): The third party address being invited. - :param room_id (str): The ID of the room to which the user is invited. - :param inviter_user_id (str): The user ID of the inviter. - :param room_alias (str): An alias for the room, for cosmetic - notifications. - :param room_avatar_url (str): The URL of the room's avatar, for cosmetic - notifications. - :param room_join_rules (str): The join rules of the email - (e.g. "public"). - :param room_name (str): The m.room.name of the room. - :param inviter_display_name (str): The current display name of the - inviter. - :param inviter_avatar_url (str): The URL of the inviter's avatar. - - :return: A deferred tuple containing: - token (str): The token which must be signed to prove authenticity. - public_keys ([{"public_key": str, "key_validity_url": str}]): - public_key is a base64-encoded ed25519 public key. - fallback_public_key: One element from public_keys. - display_name (str): A user-friendly name to represent the invited - user. - """ - - is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( - id_server_scheme, id_server, - ) - - invite_config = { - "medium": medium, - "address": address, - "room_id": room_id, - "room_alias": room_alias, - "room_avatar_url": room_avatar_url, - "room_join_rules": room_join_rules, - "room_name": room_name, - "sender": inviter_user_id, - "sender_display_name": inviter_display_name, - "sender_avatar_url": inviter_avatar_url, - } - - if self.hs.config.invite_3pid_guest: - registration_handler = self.hs.get_handlers().registration_handler - guest_access_token = yield registration_handler.guest_access_token_for( - medium=medium, - address=address, - inviter_user_id=inviter_user_id, - ) - - guest_user_info = yield self.hs.get_auth().get_user_by_access_token( - guest_access_token - ) - - invite_config.update({ - "guest_access_token": guest_access_token, - "guest_user_id": guest_user_info["user"].to_string(), - }) - - data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( - is_url, - invite_config - ) - # TODO: Check for success - token = data["token"] - public_keys = data.get("public_keys", []) - if "public_key" in data: - fallback_public_key = { - "public_key": data["public_key"], - "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( - id_server_scheme, id_server, - ), - } - else: - fallback_public_key = public_keys[0] - - if not public_keys: - public_keys.append(fallback_public_key) - display_name = data["display_name"] - defer.returnValue((token, public_keys, fallback_public_key, display_name)) - - def forget(self, user, room_id): - return self.store.forget(user.to_string(), room_id) - - -class RoomListHandler(BaseHandler): - - @defer.inlineCallbacks - def get_public_room_list(self): - room_ids = yield self.store.get_public_room_ids() - - @defer.inlineCallbacks - def handle_room(room_id): - aliases = yield self.store.get_aliases_for_room(room_id) - if not aliases: - defer.returnValue(None) - - state = yield self.state_handler.get_current_state(room_id) - - result = {"aliases": aliases, "room_id": room_id} - - name_event = state.get((EventTypes.Name, ""), None) - if name_event: - name = name_event.content.get("name", None) - if name: - result["name"] = name - - topic_event = state.get((EventTypes.Topic, ""), None) - if topic_event: - topic = topic_event.content.get("topic", None) - if topic: - result["topic"] = topic - - canonical_event = state.get((EventTypes.CanonicalAlias, ""), None) - if canonical_event: - canonical_alias = canonical_event.content.get("alias", None) - if canonical_alias: - result["canonical_alias"] = canonical_alias - - visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None) - visibility = None - if visibility_event: - visibility = visibility_event.content.get("history_visibility", None) - result["world_readable"] = visibility == "world_readable" - - guest_event = state.get((EventTypes.GuestAccess, ""), None) - guest = None - if guest_event: - guest = guest_event.content.get("guest_access", None) - result["guest_can_join"] = guest == "can_join" - - avatar_event = state.get(("m.room.avatar", ""), None) - if avatar_event: - avatar_url = avatar_event.content.get("url", None) - if avatar_url: - result["avatar_url"] = avatar_url - - result["num_joined_members"] = sum( - 1 for (event_type, _), ev in state.items() - if event_type == EventTypes.Member and ev.membership == Membership.JOIN - ) - - defer.returnValue(result) - - result = [] - for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)): - chunk_result = yield defer.gatherResults([ - handle_room(room_id) - for room_id in chunk - ], consumeErrors=True).addErrback(unwrapFirstError) - result.extend(v for v in chunk_result if v) - - # FIXME (erikj): START is no longer a valid value - defer.returnValue({"start": "START", "end": "END", "chunk": result}) - - class RoomContextHandler(BaseHandler): @defer.inlineCallbacks def get_event_context(self, user, room_id, event_id, limit, is_guest): @@ -1040,10 +376,12 @@ class RoomContextHandler(BaseHandler): now_token = yield self.hs.get_event_sources().get_current_token() def filter_evts(events): - return self._filter_events_for_client( + return filter_events_for_client( + self.store, user.to_string(), events, - is_peeking=is_guest) + is_peeking=is_guest + ) event = yield self.store.get_event(event_id, get_prev_content=True, allow_none=True) @@ -1109,7 +447,7 @@ class RoomEventSource(object): logger.warn("Stream has topological part!!!! %r", from_key) from_key = "s%s" % (from_token.stream,) - app_service = yield self.store.get_app_service_by_user_id( + app_service = self.store.get_app_service_by_user_id( user.to_string() ) if app_service: @@ -1147,8 +485,11 @@ class RoomEventSource(object): defer.returnValue((events, end_key)) - def get_current_key(self, direction='f'): - return self.store.get_room_events_max_id(direction) + def get_current_key(self): + return self.store.get_room_events_max_id() + + def get_current_key_for_room(self, room_id): + return self.store.get_room_events_max_id(room_id) @defer.inlineCallbacks def get_pagination_rows(self, user, config, key): diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py new file mode 100644 index 0000000000..b04aea0110 --- /dev/null +++ b/synapse/handlers/room_list.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 - 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import BaseHandler + +from synapse.api.constants import ( + EventTypes, JoinRules, +) +from synapse.util.async import concurrently_execute +from synapse.util.caches.response_cache import ResponseCache + +from collections import namedtuple +from unpaddedbase64 import encode_base64, decode_base64 + +import logging +import msgpack + +logger = logging.getLogger(__name__) + +REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 + + +class RoomListHandler(BaseHandler): + def __init__(self, hs): + super(RoomListHandler, self).__init__(hs) + self.response_cache = ResponseCache(hs) + self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) + + def get_local_public_room_list(self, limit=None, since_token=None, + search_filter=None): + if search_filter: + # We explicitly don't bother caching searches. + return self._get_public_room_list(limit, since_token, search_filter) + + result = self.response_cache.get((limit, since_token)) + if not result: + result = self.response_cache.set( + (limit, since_token), + self._get_public_room_list(limit, since_token) + ) + return result + + @defer.inlineCallbacks + def _get_public_room_list(self, limit=None, since_token=None, + search_filter=None): + if since_token and since_token != "END": + since_token = RoomListNextBatch.from_token(since_token) + else: + since_token = None + + rooms_to_order_value = {} + rooms_to_num_joined = {} + rooms_to_latest_event_ids = {} + + newly_visible = [] + newly_unpublished = [] + if since_token: + stream_token = since_token.stream_ordering + current_public_id = yield self.store.get_current_public_room_stream_id() + public_room_stream_id = since_token.public_room_stream_id + newly_visible, newly_unpublished = yield self.store.get_public_room_changes( + public_room_stream_id, current_public_id + ) + else: + stream_token = yield self.store.get_room_max_stream_ordering() + public_room_stream_id = yield self.store.get_current_public_room_stream_id() + + room_ids = yield self.store.get_public_room_ids_at_stream_id( + public_room_stream_id + ) + + # We want to return rooms in a particular order: the number of joined + # users. We then arbitrarily use the room_id as a tie breaker. + + @defer.inlineCallbacks + def get_order_for_room(room_id): + latest_event_ids = rooms_to_latest_event_ids.get(room_id, None) + if not latest_event_ids: + latest_event_ids = yield self.store.get_forward_extremeties_for_room( + room_id, stream_token + ) + rooms_to_latest_event_ids[room_id] = latest_event_ids + + if not latest_event_ids: + return + + joined_users = yield self.state_handler.get_current_user_in_room( + room_id, latest_event_ids, + ) + num_joined_users = len(joined_users) + rooms_to_num_joined[room_id] = num_joined_users + + if num_joined_users == 0: + return + + # We want larger rooms to be first, hence negating num_joined_users + rooms_to_order_value[room_id] = (-num_joined_users, room_id) + + yield concurrently_execute(get_order_for_room, room_ids, 10) + + sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) + sorted_rooms = [room_id for room_id, _ in sorted_entries] + + # `sorted_rooms` should now be a list of all public room ids that is + # stable across pagination. Therefore, we can use indices into this + # list as our pagination tokens. + + # Filter out rooms that we don't want to return + rooms_to_scan = [ + r for r in sorted_rooms + if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0 + ] + + total_room_count = len(rooms_to_scan) + + if since_token: + # Filter out rooms we've already returned previously + # `since_token.current_limit` is the index of the last room we + # sent down, so we exclude it and everything before/after it. + if since_token.direction_is_forward: + rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:] + else: + rooms_to_scan = rooms_to_scan[:since_token.current_limit] + rooms_to_scan.reverse() + + # Actually generate the entries. _generate_room_entry will append to + # chunk but will stop if len(chunk) > limit + chunk = [] + if limit and not search_filter: + step = limit + 1 + for i in xrange(0, len(rooms_to_scan), step): + # We iterate here because the vast majority of cases we'll stop + # at first iteration, but occaisonally _generate_room_entry + # won't append to the chunk and so we need to loop again. + # We don't want to scan over the entire range either as that + # would potentially waste a lot of work. + yield concurrently_execute( + lambda r: self._generate_room_entry( + r, rooms_to_num_joined[r], + chunk, limit, search_filter + ), + rooms_to_scan[i:i + step], 10 + ) + if len(chunk) >= limit + 1: + break + else: + yield concurrently_execute( + lambda r: self._generate_room_entry( + r, rooms_to_num_joined[r], + chunk, limit, search_filter + ), + rooms_to_scan, 5 + ) + + chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) + + # Work out the new limit of the batch for pagination, or None if we + # know there are no more results that would be returned. + # i.e., [since_token.current_limit..new_limit] is the batch of rooms + # we've returned (or the reverse if we paginated backwards) + # We tried to pull out limit + 1 rooms above, so if we have <= limit + # then we know there are no more results to return + new_limit = None + if chunk and (not limit or len(chunk) > limit): + + if not since_token or since_token.direction_is_forward: + if limit: + chunk = chunk[:limit] + last_room_id = chunk[-1]["room_id"] + else: + if limit: + chunk = chunk[-limit:] + last_room_id = chunk[0]["room_id"] + + new_limit = sorted_rooms.index(last_room_id) + + results = { + "chunk": chunk, + "total_room_count_estimate": total_room_count, + } + + if since_token: + results["new_rooms"] = bool(newly_visible) + + if not since_token or since_token.direction_is_forward: + if new_limit is not None: + results["next_batch"] = RoomListNextBatch( + stream_ordering=stream_token, + public_room_stream_id=public_room_stream_id, + current_limit=new_limit, + direction_is_forward=True, + ).to_token() + + if since_token: + results["prev_batch"] = since_token.copy_and_replace( + direction_is_forward=False, + current_limit=since_token.current_limit + 1, + ).to_token() + else: + if new_limit is not None: + results["prev_batch"] = RoomListNextBatch( + stream_ordering=stream_token, + public_room_stream_id=public_room_stream_id, + current_limit=new_limit, + direction_is_forward=False, + ).to_token() + + if since_token: + results["next_batch"] = since_token.copy_and_replace( + direction_is_forward=True, + current_limit=since_token.current_limit - 1, + ).to_token() + + defer.returnValue(results) + + @defer.inlineCallbacks + def _generate_room_entry(self, room_id, num_joined_users, chunk, limit, + search_filter): + if limit and len(chunk) > limit + 1: + # We've already got enough, so lets just drop it. + return + + result = { + "room_id": room_id, + "num_joined_members": num_joined_users, + } + + current_state_ids = yield self.state_handler.get_current_state_ids(room_id) + + event_map = yield self.store.get_events([ + event_id for key, event_id in current_state_ids.items() + if key[0] in ( + EventTypes.JoinRules, + EventTypes.Name, + EventTypes.Topic, + EventTypes.CanonicalAlias, + EventTypes.RoomHistoryVisibility, + EventTypes.GuestAccess, + "m.room.avatar", + ) + ]) + + current_state = { + (ev.type, ev.state_key): ev + for ev in event_map.values() + } + + # Double check that this is actually a public room. + join_rules_event = current_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + join_rule = join_rules_event.content.get("join_rule", None) + if join_rule and join_rule != JoinRules.PUBLIC: + defer.returnValue(None) + + aliases = yield self.store.get_aliases_for_room(room_id) + if aliases: + result["aliases"] = aliases + + name_event = yield current_state.get((EventTypes.Name, "")) + if name_event: + name = name_event.content.get("name", None) + if name: + result["name"] = name + + topic_event = current_state.get((EventTypes.Topic, "")) + if topic_event: + topic = topic_event.content.get("topic", None) + if topic: + result["topic"] = topic + + canonical_event = current_state.get((EventTypes.CanonicalAlias, "")) + if canonical_event: + canonical_alias = canonical_event.content.get("alias", None) + if canonical_alias: + result["canonical_alias"] = canonical_alias + + visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, "")) + visibility = None + if visibility_event: + visibility = visibility_event.content.get("history_visibility", None) + result["world_readable"] = visibility == "world_readable" + + guest_event = current_state.get((EventTypes.GuestAccess, "")) + guest = None + if guest_event: + guest = guest_event.content.get("guest_access", None) + result["guest_can_join"] = guest == "can_join" + + avatar_event = current_state.get(("m.room.avatar", "")) + if avatar_event: + avatar_url = avatar_event.content.get("url", None) + if avatar_url: + result["avatar_url"] = avatar_url + + if _matches_room_entry(result, search_filter): + chunk.append(result) + + @defer.inlineCallbacks + def get_remote_public_room_list(self, server_name, limit=None, since_token=None, + search_filter=None): + if search_filter: + # We currently don't support searching across federation, so we have + # to do it manually without pagination + limit = None + since_token = None + + res = yield self._get_remote_list_cached( + server_name, limit=limit, since_token=since_token, + ) + + if search_filter: + res = {"chunk": [ + entry + for entry in list(res.get("chunk", [])) + if _matches_room_entry(entry, search_filter) + ]} + + defer.returnValue(res) + + def _get_remote_list_cached(self, server_name, limit=None, since_token=None, + search_filter=None): + repl_layer = self.hs.get_replication_layer() + if search_filter: + # We can't cache when asking for search + return repl_layer.get_public_rooms( + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + ) + + result = self.remote_response_cache.get((server_name, limit, since_token)) + if not result: + result = self.remote_response_cache.set( + (server_name, limit, since_token), + repl_layer.get_public_rooms( + server_name, limit=limit, since_token=since_token, + search_filter=search_filter, + ) + ) + return result + + +class RoomListNextBatch(namedtuple("RoomListNextBatch", ( + "stream_ordering", # stream_ordering of the first public room list + "public_room_stream_id", # public room stream id for first public room list + "current_limit", # The number of previous rooms returned + "direction_is_forward", # Bool if this is a next_batch, false if prev_batch +))): + + KEY_DICT = { + "stream_ordering": "s", + "public_room_stream_id": "p", + "current_limit": "n", + "direction_is_forward": "d", + } + + REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} + + @classmethod + def from_token(cls, token): + return RoomListNextBatch(**{ + cls.REVERSE_KEY_DICT[key]: val + for key, val in msgpack.loads(decode_base64(token)).items() + }) + + def to_token(self): + return encode_base64(msgpack.dumps({ + self.KEY_DICT[key]: val + for key, val in self._asdict().items() + })) + + def copy_and_replace(self, **kwds): + return self._replace( + **kwds + ) + + +def _matches_room_entry(room_entry, search_filter): + if search_filter and search_filter.get("generic_search_term", None): + generic_search_term = search_filter["generic_search_term"].upper() + if generic_search_term in room_entry.get("name", "").upper(): + return True + elif generic_search_term in room_entry.get("topic", "").upper(): + return True + elif generic_search_term in room_entry.get("canonical_alias", "").upper(): + return True + else: + return True + + return False diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py new file mode 100644 index 0000000000..ba49075a20 --- /dev/null +++ b/synapse/handlers/room_member.py @@ -0,0 +1,735 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from signedjson.key import decode_verify_key_bytes +from signedjson.sign import verify_signed_json +from twisted.internet import defer +from unpaddedbase64 import decode_base64 + +import synapse.types +from synapse.api.constants import ( + EventTypes, Membership, +) +from synapse.api.errors import AuthError, SynapseError, Codes +from synapse.types import UserID, RoomID +from synapse.util.async import Linearizer +from synapse.util.distributor import user_left_room, user_joined_room +from ._base import BaseHandler + +logger = logging.getLogger(__name__) + +id_server_scheme = "https://" + + +class RoomMemberHandler(BaseHandler): + # TODO(paul): This handler currently contains a messy conflation of + # low-level API that works on UserID objects and so on, and REST-level + # API that takes ID strings and returns pagination chunks. These concerns + # ought to be separated out a lot better. + + def __init__(self, hs): + super(RoomMemberHandler, self).__init__(hs) + + self.member_linearizer = Linearizer() + + self.clock = hs.get_clock() + + self.distributor = hs.get_distributor() + self.distributor.declare("user_joined_room") + self.distributor.declare("user_left_room") + + @defer.inlineCallbacks + def _local_membership_update( + self, requester, target, room_id, membership, + prev_event_ids, + txn_id=None, + ratelimit=True, + content=None, + ): + if content is None: + content = {} + msg_handler = self.hs.get_handlers().message_handler + + content["membership"] = membership + if requester.is_guest: + content["kind"] = "guest" + + event, context = yield msg_handler.create_event( + { + "type": EventTypes.Member, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "state_key": target.to_string(), + + # For backwards compatibility: + "membership": membership, + }, + token_id=requester.access_token_id, + txn_id=txn_id, + prev_event_ids=prev_event_ids, + ) + + # Check if this event matches the previous membership event for the user. + duplicate = yield msg_handler.deduplicate_state_event(event, context) + if duplicate is not None: + # Discard the new event since this membership change is a no-op. + return + + yield msg_handler.handle_new_client_event( + requester, + event, + context, + extra_users=[target], + ratelimit=ratelimit, + ) + + prev_member_event_id = context.prev_state_ids.get( + (EventTypes.Member, target.to_string()), + None + ) + + if event.membership == Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: + yield user_joined_room(self.distributor, target, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target, room_id) + + @defer.inlineCallbacks + def remote_join(self, remote_room_hosts, room_id, user, content): + if len(remote_room_hosts) == 0: + raise SynapseError(404, "No known servers") + + # We don't do an auth check if we are doing an invite + # join dance for now, since we're kinda implicitly checking + # that we are allowed to join when we decide whether or not we + # need to do the invite/join dance. + yield self.hs.get_handlers().federation_handler.do_invite_join( + remote_room_hosts, + room_id, + user.to_string(), + content, + ) + yield user_joined_room(self.distributor, user, room_id) + + def reject_remote_invite(self, user_id, room_id, remote_room_hosts): + return self.hs.get_handlers().federation_handler.do_remotely_reject_invite( + remote_room_hosts, + room_id, + user_id + ) + + @defer.inlineCallbacks + def update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + ): + key = (room_id,) + + with (yield self.member_linearizer.queue(key)): + result = yield self._update_membership( + requester, + target, + room_id, + action, + txn_id=txn_id, + remote_room_hosts=remote_room_hosts, + third_party_signed=third_party_signed, + ratelimit=ratelimit, + content=content, + ) + + defer.returnValue(result) + + @defer.inlineCallbacks + def _update_membership( + self, + requester, + target, + room_id, + action, + txn_id=None, + remote_room_hosts=None, + third_party_signed=None, + ratelimit=True, + content=None, + ): + if content is None: + content = {} + + effective_membership_state = action + if action in ["kick", "unban"]: + effective_membership_state = "leave" + + if third_party_signed is not None: + replication = self.hs.get_replication_layer() + yield replication.exchange_third_party_invite( + third_party_signed["sender"], + target.to_string(), + room_id, + third_party_signed, + ) + + if not remote_room_hosts: + remote_room_hosts = [] + + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + current_state_ids = yield self.state_handler.get_current_state_ids( + room_id, latest_event_ids=latest_event_ids, + ) + + old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) + if old_state_id: + old_state = yield self.store.get_event(old_state_id, allow_none=True) + old_membership = old_state.content.get("membership") if old_state else None + if action == "unban" and old_membership != "ban": + raise SynapseError( + 403, + "Cannot unban user who was not banned" + " (membership=%s)" % old_membership, + errcode=Codes.BAD_STATE + ) + if old_membership == "ban" and action != "unban": + raise SynapseError( + 403, + "Cannot %s user who was banned" % (action,), + errcode=Codes.BAD_STATE + ) + + is_host_in_room = yield self._is_host_in_room(current_state_ids) + + if effective_membership_state == Membership.JOIN: + if requester.is_guest and not self._can_guest_join(current_state_ids): + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + + if not is_host_in_room: + inviter = yield self.get_inviter(target.to_string(), room_id) + if inviter and not self.hs.is_mine(inviter): + remote_room_hosts.append(inviter.domain) + + content["membership"] = Membership.JOIN + + profile = self.hs.get_handlers().profile_handler + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + + if requester.is_guest: + content["kind"] = "guest" + + ret = yield self.remote_join( + remote_room_hosts, room_id, target, content + ) + defer.returnValue(ret) + + elif effective_membership_state == Membership.LEAVE: + if not is_host_in_room: + # perhaps we've been invited + inviter = yield self.get_inviter(target.to_string(), room_id) + if not inviter: + raise SynapseError(404, "Not a known room") + + if self.hs.is_mine(inviter): + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath. + # + # This is a bit of a hack, because the room might still be + # active on other servers. + pass + else: + # send the rejection to the inviter's HS. + remote_room_hosts = remote_room_hosts + [inviter.domain] + + try: + ret = yield self.reject_remote_invite( + target.to_string(), room_id, remote_room_hosts + ) + defer.returnValue(ret) + except SynapseError as e: + logger.warn("Failed to reject invite: %s", e) + + yield self.store.locally_reject_invite( + target.to_string(), room_id + ) + + defer.returnValue({}) + + yield self._local_membership_update( + requester=requester, + target=target, + room_id=room_id, + membership=effective_membership_state, + txn_id=txn_id, + ratelimit=ratelimit, + prev_event_ids=latest_event_ids, + content=content, + ) + + @defer.inlineCallbacks + def send_membership_event( + self, + requester, + event, + context, + remote_room_hosts=None, + ratelimit=True, + ): + """ + Change the membership status of a user in a room. + + Args: + requester (Requester): The local user who requested the membership + event. If None, certain checks, like whether this homeserver can + act as the sender, will be skipped. + event (SynapseEvent): The membership event. + context: The context of the event. + is_guest (bool): Whether the sender is a guest. + room_hosts ([str]): Homeservers which are likely to already be in + the room, and could be danced with in order to join this + homeserver for the first time. + ratelimit (bool): Whether to rate limit this request. + Raises: + SynapseError if there was a problem changing the membership. + """ + remote_room_hosts = remote_room_hosts or [] + + target_user = UserID.from_string(event.state_key) + room_id = event.room_id + + if requester is not None: + sender = UserID.from_string(event.sender) + assert sender == requester.user, ( + "Sender (%s) must be same as requester (%s)" % + (sender, requester.user) + ) + assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) + else: + requester = synapse.types.create_requester(target_user) + + message_handler = self.hs.get_handlers().message_handler + prev_event = yield message_handler.deduplicate_state_event(event, context) + if prev_event is not None: + return + + if event.membership == Membership.JOIN: + if requester.is_guest: + guest_can_join = yield self._can_guest_join(context.prev_state_ids) + if not guest_can_join: + # This should be an auth check, but guests are a local concept, + # so don't really fit into the general auth process. + raise AuthError(403, "Guest access not allowed") + + yield message_handler.handle_new_client_event( + requester, + event, + context, + extra_users=[target_user], + ratelimit=ratelimit, + ) + + prev_member_event_id = context.prev_state_ids.get( + (EventTypes.Member, event.state_key), + None + ) + + if event.membership == Membership.JOIN: + # Only fire user_joined_room if the user has acutally joined the + # room. Don't bother if the user is just changing their profile + # info. + newly_joined = True + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + newly_joined = prev_member_event.membership != Membership.JOIN + if newly_joined: + yield user_joined_room(self.distributor, target_user, room_id) + elif event.membership == Membership.LEAVE: + if prev_member_event_id: + prev_member_event = yield self.store.get_event(prev_member_event_id) + if prev_member_event.membership == Membership.JOIN: + user_left_room(self.distributor, target_user, room_id) + + @defer.inlineCallbacks + def _can_guest_join(self, current_state_ids): + """ + Returns whether a guest can join a room based on its current state. + """ + guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None) + if not guest_access_id: + defer.returnValue(False) + + guest_access = yield self.store.get_event(guest_access_id) + + defer.returnValue( + guest_access + and guest_access.content + and "guest_access" in guest_access.content + and guest_access.content["guest_access"] == "can_join" + ) + + @defer.inlineCallbacks + def lookup_room_alias(self, room_alias): + """ + Get the room ID associated with a room alias. + + Args: + room_alias (RoomAlias): The alias to look up. + Returns: + A tuple of: + The room ID as a RoomID object. + Hosts likely to be participating in the room ([str]). + Raises: + SynapseError if room alias could not be found. + """ + directory_handler = self.hs.get_handlers().directory_handler + mapping = yield directory_handler.get_association(room_alias) + + if not mapping: + raise SynapseError(404, "No such room alias") + + room_id = mapping["room_id"] + servers = mapping["servers"] + + defer.returnValue((RoomID.from_string(room_id), servers)) + + @defer.inlineCallbacks + def get_inviter(self, user_id, room_id): + invite = yield self.store.get_invite_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + if invite: + defer.returnValue(UserID.from_string(invite.sender)) + + @defer.inlineCallbacks + def do_3pid_invite( + self, + room_id, + inviter, + medium, + address, + id_server, + requester, + txn_id + ): + invitee = yield self._lookup_3pid( + id_server, medium, address + ) + + if invitee: + yield self.update_membership( + requester, + UserID.from_string(invitee), + room_id, + "invite", + txn_id=txn_id, + ) + else: + yield self._make_and_store_3pid_invite( + requester, + id_server, + medium, + address, + room_id, + inviter, + txn_id=txn_id + ) + + @defer.inlineCallbacks + def _lookup_3pid(self, id_server, medium, address): + """Looks up a 3pid in the passed identity server. + + Args: + id_server (str): The server name (including port, if required) + of the identity server to use. + medium (str): The type of the third party identifier (e.g. "email"). + address (str): The third party identifier (e.g. "foo@example.com"). + + Returns: + str: the matrix ID of the 3pid, or None if it is not recognized. + """ + try: + data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,), + { + "medium": medium, + "address": address, + } + ) + + if "mxid" in data: + if "signatures" not in data: + raise AuthError(401, "No signatures on 3pid binding") + self.verify_any_signature(data, id_server) + defer.returnValue(data["mxid"]) + + except IOError as e: + logger.warn("Error from identity server lookup: %s" % (e,)) + defer.returnValue(None) + + @defer.inlineCallbacks + def verify_any_signature(self, data, server_hostname): + if server_hostname not in data["signatures"]: + raise AuthError(401, "No signature from server %s" % (server_hostname,)) + for key_name, signature in data["signatures"][server_hostname].items(): + key_data = yield self.hs.get_simple_http_client().get_json( + "%s%s/_matrix/identity/api/v1/pubkey/%s" % + (id_server_scheme, server_hostname, key_name,), + ) + if "public_key" not in key_data: + raise AuthError(401, "No public key named %s from %s" % + (key_name, server_hostname,)) + verify_signed_json( + data, + server_hostname, + decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"])) + ) + return + + @defer.inlineCallbacks + def _make_and_store_3pid_invite( + self, + requester, + id_server, + medium, + address, + room_id, + user, + txn_id + ): + room_state = yield self.hs.get_state_handler().get_current_state(room_id) + + inviter_display_name = "" + inviter_avatar_url = "" + member_event = room_state.get((EventTypes.Member, user.to_string())) + if member_event: + inviter_display_name = member_event.content.get("displayname", "") + inviter_avatar_url = member_event.content.get("avatar_url", "") + + canonical_room_alias = "" + canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, "")) + if canonical_alias_event: + canonical_room_alias = canonical_alias_event.content.get("alias", "") + + room_name = "" + room_name_event = room_state.get((EventTypes.Name, "")) + if room_name_event: + room_name = room_name_event.content.get("name", "") + + room_join_rules = "" + join_rules_event = room_state.get((EventTypes.JoinRules, "")) + if join_rules_event: + room_join_rules = join_rules_event.content.get("join_rule", "") + + room_avatar_url = "" + room_avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if room_avatar_event: + room_avatar_url = room_avatar_event.content.get("url", "") + + token, public_keys, fallback_public_key, display_name = ( + yield self._ask_id_server_for_third_party_invite( + id_server=id_server, + medium=medium, + address=address, + room_id=room_id, + inviter_user_id=user.to_string(), + room_alias=canonical_room_alias, + room_avatar_url=room_avatar_url, + room_join_rules=room_join_rules, + room_name=room_name, + inviter_display_name=inviter_display_name, + inviter_avatar_url=inviter_avatar_url + ) + ) + + msg_handler = self.hs.get_handlers().message_handler + yield msg_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.ThirdPartyInvite, + "content": { + "display_name": display_name, + "public_keys": public_keys, + + # For backwards compatibility: + "key_validity_url": fallback_public_key["key_validity_url"], + "public_key": fallback_public_key["public_key"], + }, + "room_id": room_id, + "sender": user.to_string(), + "state_key": token, + }, + txn_id=txn_id, + ) + + @defer.inlineCallbacks + def _ask_id_server_for_third_party_invite( + self, + id_server, + medium, + address, + room_id, + inviter_user_id, + room_alias, + room_avatar_url, + room_join_rules, + room_name, + inviter_display_name, + inviter_avatar_url + ): + """ + Asks an identity server for a third party invite. + + Args: + id_server (str): hostname + optional port for the identity server. + medium (str): The literal string "email". + address (str): The third party address being invited. + room_id (str): The ID of the room to which the user is invited. + inviter_user_id (str): The user ID of the inviter. + room_alias (str): An alias for the room, for cosmetic notifications. + room_avatar_url (str): The URL of the room's avatar, for cosmetic + notifications. + room_join_rules (str): The join rules of the email (e.g. "public"). + room_name (str): The m.room.name of the room. + inviter_display_name (str): The current display name of the + inviter. + inviter_avatar_url (str): The URL of the inviter's avatar. + + Returns: + A deferred tuple containing: + token (str): The token which must be signed to prove authenticity. + public_keys ([{"public_key": str, "key_validity_url": str}]): + public_key is a base64-encoded ed25519 public key. + fallback_public_key: One element from public_keys. + display_name (str): A user-friendly name to represent the invited + user. + """ + + is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( + id_server_scheme, id_server, + ) + + invite_config = { + "medium": medium, + "address": address, + "room_id": room_id, + "room_alias": room_alias, + "room_avatar_url": room_avatar_url, + "room_join_rules": room_join_rules, + "room_name": room_name, + "sender": inviter_user_id, + "sender_display_name": inviter_display_name, + "sender_avatar_url": inviter_avatar_url, + } + + if self.hs.config.invite_3pid_guest: + registration_handler = self.hs.get_handlers().registration_handler + guest_access_token = yield registration_handler.guest_access_token_for( + medium=medium, + address=address, + inviter_user_id=inviter_user_id, + ) + + guest_user_info = yield self.hs.get_auth().get_user_by_access_token( + guest_access_token + ) + + invite_config.update({ + "guest_access_token": guest_access_token, + "guest_user_id": guest_user_info["user"].to_string(), + }) + + data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( + is_url, + invite_config + ) + # TODO: Check for success + token = data["token"] + public_keys = data.get("public_keys", []) + if "public_key" in data: + fallback_public_key = { + "public_key": data["public_key"], + "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % ( + id_server_scheme, id_server, + ), + } + else: + fallback_public_key = public_keys[0] + + if not public_keys: + public_keys.append(fallback_public_key) + display_name = data["display_name"] + defer.returnValue((token, public_keys, fallback_public_key, display_name)) + + @defer.inlineCallbacks + def forget(self, user, room_id): + user_id = user.to_string() + + member = yield self.state_handler.get_current_state( + room_id=room_id, + event_type=EventTypes.Member, + state_key=user_id + ) + membership = member.membership if member else None + + if membership is not None and membership != Membership.LEAVE: + raise SynapseError(400, "User %s in room %s" % ( + user_id, room_id + )) + + if membership: + yield self.store.forget(user_id, room_id) + + @defer.inlineCallbacks + def _is_host_in_room(self, current_state_ids): + # Have we just created the room, and is this about to be the very + # first member event? + create_event_id = current_state_ids.get(("m.room.create", "")) + if len(current_state_ids) == 1 and create_event_id: + defer.returnValue(self.hs.is_mine_id(create_event_id)) + + for (etype, state_key), event_id in current_state_ids.items(): + if etype != EventTypes.Member or not self.hs.is_mine_id(state_key): + continue + + event = yield self.store.get_event(event_id, allow_none=True) + if not event: + continue + + if event.membership == Membership.JOIN: + defer.returnValue(True) + + defer.returnValue(False) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 9937d8dd7f..df75d70fac 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -21,6 +21,7 @@ from synapse.api.constants import Membership, EventTypes from synapse.api.filtering import Filter from synapse.api.errors import SynapseError from synapse.events.utils import serialize_event +from synapse.visibility import filter_events_for_client from unpaddedbase64 import decode_base64, encode_base64 @@ -172,8 +173,8 @@ class SearchHandler(BaseHandler): filtered_events = search_filter.filter([r["event"] for r in results]) - events = yield self._filter_events_for_client( - user.to_string(), filtered_events + events = yield filter_events_for_client( + self.store, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -223,8 +224,8 @@ class SearchHandler(BaseHandler): r["event"] for r in results ]) - events = yield self._filter_events_for_client( - user.to_string(), filtered_events + events = yield filter_events_for_client( + self.store, user.to_string(), filtered_events ) room_events.extend(events) @@ -281,12 +282,12 @@ class SearchHandler(BaseHandler): event.room_id, event.event_id, before_limit, after_limit ) - res["events_before"] = yield self._filter_events_for_client( - user.to_string(), res["events_before"] + res["events_before"] = yield filter_events_for_client( + self.store, user.to_string(), res["events_before"] ) - res["events_after"] = yield self._filter_events_for_client( - user.to_string(), res["events_after"] + res["events_after"] = yield filter_events_for_client( + self.store, user.to_string(), res["events_after"] ) res["start"] = now_token.copy_and_replace( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 1f6fde8e8a..a86996689c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseHandler - -from synapse.streams.config import PaginationConfig from synapse.api.constants import Membership, EventTypes -from synapse.util import unwrapFirstError -from synapse.util.logcontext import LoggingContext, preserve_fn +from synapse.util.async import concurrently_execute +from synapse.util.logcontext import LoggingContext from synapse.util.metrics import Measure +from synapse.util.caches.response_cache import ResponseCache from synapse.push.clientformat import format_push_rules_for_user +from synapse.visibility import filter_events_for_client from twisted.internet import defer @@ -35,6 +34,8 @@ SyncConfig = collections.namedtuple("SyncConfig", [ "user", "filter_collection", "is_guest", + "request_key", + "device_id", ]) @@ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. + "to_device", # List of direct messages for the device. ])): __slots__ = [] @@ -126,18 +128,22 @@ class SyncResult(collections.namedtuple("SyncResult", [ self.joined or self.invited or self.archived or - self.account_data + self.account_data or + self.to_device ) -class SyncHandler(BaseHandler): +class SyncHandler(object): def __init__(self, hs): - super(SyncHandler, self).__init__(hs) + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + self.presence_handler = hs.get_presence_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() + self.response_cache = ResponseCache(hs) + self.state = hs.get_state_handler() - @defer.inlineCallbacks def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): """Get the sync for a client if we have new data for it now. Otherwise @@ -146,7 +152,19 @@ class SyncHandler(BaseHandler): Returns: A Deferred SyncResult. """ + result = self.response_cache.get(sync_config.request_key) + if not result: + result = self.response_cache.set( + sync_config.request_key, + self._wait_for_sync_for_user( + sync_config, since_token, timeout, full_state + ) + ) + return result + @defer.inlineCallbacks + def _wait_for_sync_for_user(self, sync_config, since_token, timeout, + full_state): context = LoggingContext.current_context() if context: if since_token is None: @@ -179,197 +197,15 @@ class SyncHandler(BaseHandler): Returns: A Deferred SyncResult. """ - if since_token is None or full_state: - return self.full_state_sync(sync_config, since_token) - else: - return self.incremental_sync_with_gap(sync_config, since_token) - - @defer.inlineCallbacks - def full_state_sync(self, sync_config, timeline_since_token): - """Get a sync for a client which is starting without any state. - - If a 'message_since_token' is given, only timeline events which have - happened since that token will be returned. - - Returns: - A Deferred SyncResult. - """ - now_token = yield self.event_sources.get_current_token() - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token - ) - - presence_stream = self.event_sources.sources["presence"] - # TODO (mjark): This looks wrong, shouldn't we be getting the presence - # UP to the present rather than after the present? - pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( - user=sync_config.user, - pagination_config=pagination_config.get_source_config("presence"), - key=None - ) - - membership_list = ( - Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN - ) - - room_list = yield self.store.get_rooms_for_user_where_membership_is( - user_id=sync_config.user.to_string(), - membership_list=membership_list - ) - - account_data, account_data_by_room = ( - yield self.store.get_account_data_for_user( - sync_config.user.to_string() - ) - ) - - account_data['m.push_rules'] = yield self.push_rules_for_user( - sync_config.user - ) - - tags_by_room = yield self.store.get_tags_for_user( - sync_config.user.to_string() - ) - - joined = [] - invited = [] - archived = [] - deferreds = [] - - room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)] - for room_list_chunk in room_list_chunks: - for event in room_list_chunk: - if event.membership == Membership.JOIN: - room_sync_deferred = preserve_fn( - self.full_state_sync_for_joined_room - )( - room_id=event.room_id, - sync_config=sync_config, - now_token=now_token, - timeline_since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(joined.append) - deferreds.append(room_sync_deferred) - elif event.membership == Membership.INVITE: - invite = yield self.store.get_event(event.event_id) - invited.append(InvitedSyncResult( - room_id=event.room_id, - invite=invite, - )) - elif event.membership in (Membership.LEAVE, Membership.BAN): - # Always send down rooms we were banned or kicked from. - if not sync_config.filter_collection.include_leave: - if event.membership == Membership.LEAVE: - if sync_config.user.to_string() == event.sender: - continue - - leave_token = now_token.copy_and_replace( - "room_key", "s%d" % (event.stream_ordering,) - ) - room_sync_deferred = preserve_fn( - self.full_state_sync_for_archived_room - )( - sync_config=sync_config, - room_id=event.room_id, - leave_event_id=event.event_id, - leave_token=leave_token, - timeline_since_token=timeline_since_token, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - ) - room_sync_deferred.addCallback(archived.append) - deferreds.append(room_sync_deferred) - - yield defer.gatherResults( - deferreds, consumeErrors=True - ).addErrback(unwrapFirstError) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - - @defer.inlineCallbacks - def full_state_sync_for_joined_room(self, room_id, sync_config, - now_token, timeline_since_token, - ephemeral_by_room, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred JoinedSyncResult. - """ - - batch = yield self.load_filtered_recents( - room_id, sync_config, now_token, since_token=timeline_since_token - ) - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id, sync_config, - now_token=now_token, - since_token=timeline_since_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=True, - ) - - defer.returnValue(room_sync) + return self.generate_sync_result(sync_config, since_token, full_state) @defer.inlineCallbacks def push_rules_for_user(self, user): user_id = user.to_string() - rawrules = yield self.store.get_push_rules_for_user(user_id) - enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) - rules = format_push_rules_for_user(user, rawrules, enabled_map) + rules = yield self.store.get_push_rules_for_user(user_id) + rules = format_push_rules_for_user(user, rules) defer.returnValue(rules) - def account_data_for_user(self, account_data): - account_data_events = [] - - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - - def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): - account_data_events = [] - tags = tags_by_room.get(room_id) - if tags is not None: - account_data_events.append({ - "type": "m.tag", - "content": {"tags": tags}, - }) - - account_data = account_data_by_room.get(room_id, {}) - for account_data_type, content in account_data.items(): - account_data_events.append({ - "type": account_data_type, - "content": content, - }) - - return account_data_events - @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): """Get the ephemeral events for each room the user is in @@ -432,255 +268,45 @@ class SyncHandler(BaseHandler): defer.returnValue((now_token, ephemeral_by_room)) - def full_state_sync_for_archived_room(self, room_id, sync_config, - leave_event_id, leave_token, - timeline_since_token, tags_by_room, - account_data_by_room): - """Sync a room for a client which is starting without any state - Returns: - A Deferred ArchivedSyncResult. - """ - - return self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event_id, timeline_since_token, tags_by_room, - account_data_by_room, full_state=True, leave_token=leave_token, - ) - @defer.inlineCallbacks - def incremental_sync_with_gap(self, sync_config, since_token): - """ Get the incremental delta needed to bring the client up to - date with the server. - Returns: - A Deferred SyncResult. + def _load_filtered_recents(self, room_id, sync_config, now_token, + since_token=None, recents=None, newly_joined_room=False): """ - now_token = yield self.event_sources.get_current_token() - - rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string()) - room_ids = [room.room_id for room in rooms] - - presence_source = self.event_sources.sources["presence"] - presence, presence_key = yield presence_source.get_new_events( - user=sync_config.user, - from_key=since_token.presence_key, - limit=sync_config.filter_collection.presence_limit(), - room_ids=room_ids, - is_guest=sync_config.is_guest, - ) - now_token = now_token.copy_and_replace("presence_key", presence_key) - - now_token, ephemeral_by_room = yield self.ephemeral_by_room( - sync_config, now_token, since_token - ) - - rm_handler = self.hs.get_handlers().room_member_handler - app_service = yield self.store.get_app_service_by_user_id( - sync_config.user.to_string() - ) - if app_service: - rooms = yield self.store.get_app_service_rooms(app_service) - joined_room_ids = set(r.room_id for r in rooms) - else: - joined_room_ids = yield rm_handler.get_joined_rooms_for_user( - sync_config.user - ) - - user_id = sync_config.user.to_string() - - timeline_limit = sync_config.filter_collection.timeline_limit() - - tags_by_room = yield self.store.get_updated_tags( - user_id, - since_token.account_data_key, - ) - - account_data, account_data_by_room = ( - yield self.store.get_updated_account_data_for_user( - user_id, - since_token.account_data_key, - ) - ) - - push_rules_changed = yield self.store.have_push_rules_changed_for_user( - user_id, int(since_token.push_rules_key) - ) - - if push_rules_changed: - account_data["m.push_rules"] = yield self.push_rules_for_user( - sync_config.user - ) - - # Get a list of membership change events that have happened. - rooms_changed = yield self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key - ) - - mem_change_events_by_room_id = {} - for event in rooms_changed: - mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) - - newly_joined_rooms = [] - archived = [] - invited = [] - for room_id, events in mem_change_events_by_room_id.items(): - non_joins = [e for e in events if e.membership != Membership.JOIN] - has_join = len(non_joins) != len(events) - - # We want to figure out if we joined the room at some point since - # the last sync (even if we have since left). This is to make sure - # we do send down the room, and with full state, where necessary - if room_id in joined_room_ids or has_join: - old_state = yield self.get_state_at(room_id, since_token) - old_mem_ev = old_state.get((EventTypes.Member, user_id), None) - if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: - newly_joined_rooms.append(room_id) - - if room_id in joined_room_ids: - continue - - if not non_joins: - continue - - # Only bother if we're still currently invited - should_invite = non_joins[-1].membership == Membership.INVITE - if should_invite: - room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) - if room_sync: - invited.append(room_sync) - - # Always include leave/ban events. Just take the last one. - # TODO: How do we handle ban -> leave in same batch? - leave_events = [ - e for e in non_joins - if e.membership in (Membership.LEAVE, Membership.BAN) - ] - - if leave_events: - leave_event = leave_events[-1] - room_sync = yield self.incremental_sync_for_archived_room( - sync_config, room_id, leave_event.event_id, since_token, - tags_by_room, account_data_by_room, - full_state=room_id in newly_joined_rooms - ) - if room_sync: - archived.append(room_sync) - - # Get all events for rooms we're currently joined to. - room_to_events = yield self.store.get_room_events_stream_for_rooms( - room_ids=joined_room_ids, - from_key=since_token.room_key, - to_key=now_token.room_key, - limit=timeline_limit + 1, - ) - - joined = [] - # We loop through all room ids, even if there are no new events, in case - # there are non room events taht we need to notify about. - for room_id in joined_room_ids: - room_entry = room_to_events.get(room_id, None) - - if room_entry: - events, start_key = room_entry - - prev_batch_token = now_token.copy_and_replace("room_key", start_key) - - newly_joined_room = room_id in newly_joined_rooms - full_state = newly_joined_room - - batch = yield self.load_filtered_recents( - room_id, sync_config, prev_batch_token, - since_token=since_token, - recents=events, - newly_joined_room=newly_joined_room, - ) - else: - batch = TimelineBatch( - events=[], - prev_batch=since_token, - limited=False, - ) - full_state = False - - room_sync = yield self.incremental_sync_with_gap_for_room( - room_id=room_id, - sync_config=sync_config, - since_token=since_token, - now_token=now_token, - ephemeral_by_room=ephemeral_by_room, - tags_by_room=tags_by_room, - account_data_by_room=account_data_by_room, - batch=batch, - full_state=full_state, - ) - if room_sync: - joined.append(room_sync) - - # For each newly joined room, we want to send down presence of - # existing users. - presence_handler = self.hs.get_handlers().presence_handler - extra_presence_users = set() - for room_id in newly_joined_rooms: - users = yield self.store.get_users_in_room(event.room_id) - extra_presence_users.update(users) - - # For each new member, send down presence. - for joined_sync in joined: - it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values()) - for event in it: - if event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - extra_presence_users.add(event.state_key) - - states = yield presence_handler.get_states( - [u for u in extra_presence_users if u != user_id], - as_event=True, - ) - presence.extend(states) - - account_data_for_user = sync_config.filter_collection.filter_account_data( - self.account_data_for_user(account_data) - ) - - presence = sync_config.filter_collection.filter_presence( - presence - ) - - defer.returnValue(SyncResult( - presence=presence, - account_data=account_data_for_user, - joined=joined, - invited=invited, - archived=archived, - next_batch=now_token, - )) - - @defer.inlineCallbacks - def load_filtered_recents(self, room_id, sync_config, now_token, - since_token=None, recents=None, newly_joined_room=False): - """ - :returns a Deferred TimelineBatch + Returns: + a Deferred TimelineBatch """ with Measure(self.clock, "load_filtered_recents"): - filtering_factor = 2 timeline_limit = sync_config.filter_collection.timeline_limit() - load_limit = max(timeline_limit * filtering_factor, 10) - max_repeat = 5 # Only try a few times per room, otherwise - room_key = now_token.room_key - end_key = room_key + block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True else: limited = False - if recents is not None: + if recents: recents = sync_config.filter_collection.filter_room_timeline(recents) - recents = yield self._filter_events_for_client( + recents = yield filter_events_for_client( + self.store, sync_config.user.to_string(), recents, ) else: recents = [] + if not limited or block_all_timeline: + defer.returnValue(TimelineBatch( + events=recents, + prev_batch=now_token, + limited=False + )) + + filtering_factor = 2 + load_limit = max(timeline_limit * filtering_factor, 10) + max_repeat = 5 # Only try a few times per room, otherwise + room_key = now_token.room_key + end_key = room_key + since_key = None if since_token and not newly_joined_room: since_key = since_token.room_key @@ -695,7 +321,8 @@ class SyncHandler(BaseHandler): loaded_recents = sync_config.filter_collection.filter_room_timeline( events ) - loaded_recents = yield self._filter_events_for_client( + loaded_recents = yield filter_events_for_client( + self.store, sync_config.user.to_string(), loaded_recents, ) @@ -723,122 +350,32 @@ class SyncHandler(BaseHandler): )) @defer.inlineCallbacks - def incremental_sync_with_gap_for_room(self, room_id, sync_config, - since_token, now_token, - ephemeral_by_room, tags_by_room, - account_data_by_room, - batch, full_state=False): - state = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, now_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - ephemeral = sync_config.filter_collection.filter_room_ephemeral( - ephemeral_by_room.get(room_id, []) - ) - - unread_notifications = {} - room_sync = JoinedSyncResult( - room_id=room_id, - timeline=batch, - state=state, - ephemeral=ephemeral, - account_data=account_data, - unread_notifications=unread_notifications, - ) - - if room_sync: - notifs = yield self.unread_notifs_for_room_id( - room_id, sync_config - ) - - if notifs is not None: - unread_notifications["notification_count"] = notifs["notify_count"] - unread_notifications["highlight_count"] = notifs["highlight_count"] - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks - def incremental_sync_for_archived_room(self, sync_config, room_id, leave_event_id, - since_token, tags_by_room, - account_data_by_room, full_state, - leave_token=None): - """ Get the incremental delta needed to bring the client up to date for - the archived room. - Returns: - A Deferred ArchivedSyncResult - """ - - if not leave_token: - stream_token = yield self.store.get_stream_token_for_event( - leave_event_id - ) - - leave_token = since_token.copy_and_replace("room_key", stream_token) - - if since_token and since_token.is_after(leave_token): - defer.returnValue(None) - - batch = yield self.load_filtered_recents( - room_id, sync_config, leave_token, since_token, - ) - - logger.debug("Recents %r", batch) - - state_events_delta = yield self.compute_state_delta( - room_id, batch, sync_config, since_token, leave_token, - full_state=full_state - ) - - account_data = self.account_data_for_room( - room_id, tags_by_room, account_data_by_room - ) - - account_data = sync_config.filter_collection.filter_room_account_data( - account_data - ) - - room_sync = ArchivedSyncResult( - room_id=room_id, - timeline=batch, - state=state_events_delta, - account_data=account_data, - ) - - logger.debug("Room sync: %r", room_sync) - - defer.returnValue(room_sync) - - @defer.inlineCallbacks def get_state_after_event(self, event): """ Get the room state after the given event - :param synapse.events.EventBase event: event of interest - :return: A Deferred map from ((type, state_key)->Event) + Args: + event(synapse.events.EventBase): event of interest + + Returns: + A Deferred map from ((type, state_key)->Event) """ - state = yield self.store.get_state_for_event(event.event_id) + state_ids = yield self.store.get_state_ids_for_event(event.event_id) if event.is_state(): - state = state.copy() - state[(event.type, event.state_key)] = event - defer.returnValue(state) + state_ids = state_ids.copy() + state_ids[(event.type, event.state_key)] = event.event_id + defer.returnValue(state_ids) @defer.inlineCallbacks def get_state_at(self, room_id, stream_position): """ Get the room state at a particular stream position - :param str room_id: room for which to get state - :param StreamToken stream_position: point at which to get state - :returns: A Deferred map from ((type, state_key)->Event) + + Args: + room_id(str): room for which to get state + stream_position(StreamToken): point at which to get state + + Returns: + A Deferred map from ((type, state_key)->Event) """ last_events, token = yield self.store.get_recent_events_for_room( room_id, end_token=stream_position.room_key, limit=1, @@ -859,15 +396,18 @@ class SyncHandler(BaseHandler): """ Works out the differnce in state between the start of the timeline and the previous sync. - :param str room_id - :param TimelineBatch batch: The timeline batch for the room that will - be sent to the user. - :param sync_config - :param str since_token: Token of the end of the previous batch. May be None. - :param str now_token: Token of the end of the current batch. - :param bool full_state: Whether to force returning the full state. + Args: + room_id(str): + batch(synapse.handlers.sync.TimelineBatch): The timeline batch for + the room that will be sent to the user. + sync_config(synapse.handlers.sync.SyncConfig): + since_token(str|None): Token of the end of the previous batch. May + be None. + now_token(str): Token of the end of the current batch. + full_state(bool): Whether to force returning the full state. - :returns A new event dictionary + Returns: + A deferred new event dictionary """ # TODO(mjark) Check if the state events were received by the server # after the previous sync, since we need to include those state @@ -877,80 +417,66 @@ class SyncHandler(BaseHandler): with Measure(self.clock, "compute_state_delta"): if full_state: if batch: - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state = yield self.store.get_state_for_event( + state_ids = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) else: - current_state = yield self.get_state_at( + current_state_ids = yield self.get_state_at( room_id, stream_position=now_token ) - state = current_state + state_ids = current_state_ids timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, - timeline_start=state, + timeline_start=state_ids, previous={}, - current=current_state, + current=current_state_ids, ) elif batch.limited: state_at_previous_sync = yield self.get_state_at( room_id, stream_position=since_token ) - current_state = yield self.store.get_state_for_event( + current_state_ids = yield self.store.get_state_ids_for_event( batch.events[-1].event_id ) - state_at_timeline_start = yield self.store.get_state_for_event( + state_at_timeline_start = yield self.store.get_state_ids_for_event( batch.events[0].event_id ) timeline_state = { - (event.type, event.state_key): event + (event.type, event.state_key): event.event_id for event in batch.events if event.is_state() } - state = _calculate_state( + state_ids = _calculate_state( timeline_contains=timeline_state, timeline_start=state_at_timeline_start, previous=state_at_previous_sync, - current=current_state, + current=current_state_ids, ) else: - state = {} + state_ids = {} - defer.returnValue({ - (e.type, e.state_key): e - for e in sync_config.filter_collection.filter_room_state(state.values()) - }) - - def check_joined_room(self, sync_config, state_delta): - """ - Check if the user has just joined the given room (so should - be given the full state) + state = {} + if state_ids: + state = yield self.store.get_events(state_ids.values()) - :param sync_config: - :param dict[(str,str), synapse.events.FrozenEvent] state_delta: the - difference in state since the last sync - - :returns A deferred Tuple (state_delta, limited) - """ - join_event = state_delta.get(( - EventTypes.Member, sync_config.user.to_string()), None) - if join_event is not None: - if join_event.content["membership"] == Membership.JOIN: - return True - return False + defer.returnValue({ + (e.type, e.state_key): e + for e in sync_config.filter_collection.filter_room_state(state.values()) + }) @defer.inlineCallbacks def unread_notifs_for_room_id(self, room_id, sync_config): @@ -968,9 +494,613 @@ class SyncHandler(BaseHandler): ) defer.returnValue(notifs) - # There is no new information in this period, so your notification - # count is whatever it was last time. - defer.returnValue(None) + # There is no new information in this period, so your notification + # count is whatever it was last time. + defer.returnValue(None) + + @defer.inlineCallbacks + def generate_sync_result(self, sync_config, since_token=None, full_state=False): + """Generates a sync result. + + Args: + sync_config (SyncConfig) + since_token (StreamToken) + full_state (bool) + + Returns: + Deferred(SyncResult) + """ + + # NB: The now_token gets changed by some of the generate_sync_* methods, + # this is due to some of the underlying streams not supporting the ability + # to query up to a given point. + # Always use the `now_token` in `SyncResultBuilder` + now_token = yield self.event_sources.get_current_token() + + sync_result_builder = SyncResultBuilder( + sync_config, full_state, + since_token=since_token, + now_token=now_token, + ) + + account_data_by_room = yield self._generate_sync_entry_for_account_data( + sync_result_builder + ) + + res = yield self._generate_sync_entry_for_rooms( + sync_result_builder, account_data_by_room + ) + newly_joined_rooms, newly_joined_users = res + + block_all_presence_data = ( + since_token is None and + sync_config.filter_collection.blocks_all_presence() + ) + if not block_all_presence_data: + yield self._generate_sync_entry_for_presence( + sync_result_builder, newly_joined_rooms, newly_joined_users + ) + + yield self._generate_sync_entry_for_to_device(sync_result_builder) + + defer.returnValue(SyncResult( + presence=sync_result_builder.presence, + account_data=sync_result_builder.account_data, + joined=sync_result_builder.joined, + invited=sync_result_builder.invited, + archived=sync_result_builder.archived, + to_device=sync_result_builder.to_device, + next_batch=sync_result_builder.now_token, + )) + + @defer.inlineCallbacks + def _generate_sync_entry_for_to_device(self, sync_result_builder): + """Generates the portion of the sync response. Populates + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ + user_id = sync_result_builder.sync_config.user.to_string() + device_id = sync_result_builder.sync_config.device_id + now_token = sync_result_builder.now_token + since_stream_id = 0 + if sync_result_builder.since_token is not None: + since_stream_id = int(sync_result_builder.since_token.to_device_key) + + if since_stream_id != int(now_token.to_device_key): + # We only delete messages when a new message comes in, but that's + # fine so long as we delete them at some point. + + logger.debug("Deleting messages up to %d", since_stream_id) + yield self.store.delete_messages_for_device( + user_id, device_id, since_stream_id + ) + + logger.debug("Getting messages up to %d", now_token.to_device_key) + messages, stream_id = yield self.store.get_new_messages_for_device( + user_id, device_id, since_stream_id, now_token.to_device_key + ) + logger.debug("Got messages up to %d: %r", stream_id, messages) + sync_result_builder.now_token = now_token.copy_and_replace( + "to_device_key", stream_id + ) + sync_result_builder.to_device = messages + else: + sync_result_builder.to_device = [] + + @defer.inlineCallbacks + def _generate_sync_entry_for_account_data(self, sync_result_builder): + """Generates the account data portion of the sync response. Populates + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + + Returns: + Deferred(dict): A dictionary containing the per room account data. + """ + sync_config = sync_result_builder.sync_config + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + + if since_token and not sync_result_builder.full_state: + account_data, account_data_by_room = ( + yield self.store.get_updated_account_data_for_user( + user_id, + since_token.account_data_key, + ) + ) + + push_rules_changed = yield self.store.have_push_rules_changed_for_user( + user_id, int(since_token.push_rules_key) + ) + + if push_rules_changed: + account_data["m.push_rules"] = yield self.push_rules_for_user( + sync_config.user + ) + else: + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user( + sync_config.user.to_string() + ) + ) + + account_data['m.push_rules'] = yield self.push_rules_for_user( + sync_config.user + ) + + account_data_for_user = sync_config.filter_collection.filter_account_data([ + {"type": account_data_type, "content": content} + for account_data_type, content in account_data.items() + ]) + + sync_result_builder.account_data = account_data_for_user + + defer.returnValue(account_data_by_room) + + @defer.inlineCallbacks + def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms, + newly_joined_users): + """Generates the presence portion of the sync response. Populates the + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + newly_joined_rooms(list): List of rooms that the user has joined + since the last sync (or empty if an initial sync) + newly_joined_users(list): List of users that have joined rooms + since the last sync (or empty if an initial sync) + """ + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + user = sync_result_builder.sync_config.user + + presence_source = self.event_sources.sources["presence"] + + since_token = sync_result_builder.since_token + if since_token and not sync_result_builder.full_state: + presence_key = since_token.presence_key + include_offline = True + else: + presence_key = None + include_offline = False + + presence, presence_key = yield presence_source.get_new_events( + user=user, + from_key=presence_key, + is_guest=sync_config.is_guest, + include_offline=include_offline, + ) + sync_result_builder.now_token = now_token.copy_and_replace( + "presence_key", presence_key + ) + + extra_users_ids = set(newly_joined_users) + for room_id in newly_joined_rooms: + users = yield self.state.get_current_user_in_room(room_id) + extra_users_ids.update(users) + extra_users_ids.discard(user.to_string()) + + states = yield self.presence_handler.get_states( + extra_users_ids, + as_event=True, + ) + presence.extend(states) + + # Deduplicate the presence entries so that there's at most one per user + presence = {p["content"]["user_id"]: p for p in presence}.values() + + presence = sync_config.filter_collection.filter_presence( + presence + ) + + sync_result_builder.presence = presence + + @defer.inlineCallbacks + def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room): + """Generates the rooms portion of the sync response. Populates the + `sync_result_builder` with the result. + + Args: + sync_result_builder(SyncResultBuilder) + account_data_by_room(dict): Dictionary of per room account data + + Returns: + Deferred(tuple): Returns a 2-tuple of + `(newly_joined_rooms, newly_joined_users)` + """ + user_id = sync_result_builder.sync_config.user.to_string() + block_all_room_ephemeral = ( + sync_result_builder.since_token is None and + sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral() + ) + + if block_all_room_ephemeral: + ephemeral_by_room = {} + else: + now_token, ephemeral_by_room = yield self.ephemeral_by_room( + sync_result_builder.sync_config, + now_token=sync_result_builder.now_token, + since_token=sync_result_builder.since_token, + ) + sync_result_builder.now_token = now_token + + ignored_account_data = yield self.store.get_global_account_data_by_type_for_user( + "m.ignored_user_list", user_id=user_id, + ) + + if ignored_account_data: + ignored_users = ignored_account_data.get("ignored_users", {}).keys() + else: + ignored_users = frozenset() + + if sync_result_builder.since_token: + res = yield self._get_rooms_changed(sync_result_builder, ignored_users) + room_entries, invited, newly_joined_rooms = res + + tags_by_room = yield self.store.get_updated_tags( + user_id, + sync_result_builder.since_token.account_data_key, + ) + else: + res = yield self._get_all_rooms(sync_result_builder, ignored_users) + room_entries, invited, newly_joined_rooms = res + + tags_by_room = yield self.store.get_tags_for_user(user_id) + + def handle_room_entries(room_entry): + return self._generate_room_entry( + sync_result_builder, + ignored_users, + room_entry, + ephemeral=ephemeral_by_room.get(room_entry.room_id, []), + tags=tags_by_room.get(room_entry.room_id), + account_data=account_data_by_room.get(room_entry.room_id, {}), + always_include=sync_result_builder.full_state, + ) + + yield concurrently_execute(handle_room_entries, room_entries, 10) + + sync_result_builder.invited.extend(invited) + + # Now we want to get any newly joined users + newly_joined_users = set() + if sync_result_builder.since_token: + for joined_sync in sync_result_builder.joined: + it = itertools.chain( + joined_sync.timeline.events, joined_sync.state.values() + ) + for event in it: + if event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + newly_joined_users.add(event.state_key) + + defer.returnValue((newly_joined_rooms, newly_joined_users)) + + @defer.inlineCallbacks + def _get_rooms_changed(self, sync_result_builder, ignored_users): + """Gets the the changes that have happened since the last sync. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)` + """ + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + assert since_token + + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service: + rooms = yield self.store.get_app_service_rooms(app_service) + joined_room_ids = set(r.room_id for r in rooms) + else: + rooms = yield self.store.get_rooms_for_user(user_id) + joined_room_ids = set(r.room_id for r in rooms) + + # Get a list of membership change events that have happened. + rooms_changed = yield self.store.get_membership_changes_for_user( + user_id, since_token.room_key, now_token.room_key + ) + + mem_change_events_by_room_id = {} + for event in rooms_changed: + mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) + + newly_joined_rooms = [] + room_entries = [] + invited = [] + for room_id, events in mem_change_events_by_room_id.items(): + non_joins = [e for e in events if e.membership != Membership.JOIN] + has_join = len(non_joins) != len(events) + + # We want to figure out if we joined the room at some point since + # the last sync (even if we have since left). This is to make sure + # we do send down the room, and with full state, where necessary + if room_id in joined_room_ids or has_join: + old_state_ids = yield self.get_state_at(room_id, since_token) + old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) + old_mem_ev = None + if old_mem_ev_id: + old_mem_ev = yield self.store.get_event( + old_mem_ev_id, allow_none=True + ) + if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: + newly_joined_rooms.append(room_id) + + if room_id in joined_room_ids: + continue + + if not non_joins: + continue + + # Only bother if we're still currently invited + should_invite = non_joins[-1].membership == Membership.INVITE + if should_invite: + if event.sender not in ignored_users: + room_sync = InvitedSyncResult(room_id, invite=non_joins[-1]) + if room_sync: + invited.append(room_sync) + + # Always include leave/ban events. Just take the last one. + # TODO: How do we handle ban -> leave in same batch? + leave_events = [ + e for e in non_joins + if e.membership in (Membership.LEAVE, Membership.BAN) + ] + + if leave_events: + leave_event = leave_events[-1] + leave_stream_token = yield self.store.get_stream_token_for_event( + leave_event.event_id + ) + leave_token = since_token.copy_and_replace( + "room_key", leave_stream_token + ) + + if since_token and since_token.is_after(leave_token): + continue + + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="archived", + events=None, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=leave_token, + )) + + timeline_limit = sync_config.filter_collection.timeline_limit() + + # Get all events for rooms we're currently joined to. + room_to_events = yield self.store.get_room_events_stream_for_rooms( + room_ids=joined_room_ids, + from_key=since_token.room_key, + to_key=now_token.room_key, + limit=timeline_limit + 1, + ) + + # We loop through all room ids, even if there are no new events, in case + # there are non room events taht we need to notify about. + for room_id in joined_room_ids: + room_entry = room_to_events.get(room_id, None) + + if room_entry: + events, start_key = room_entry + + prev_batch_token = now_token.copy_and_replace("room_key", start_key) + + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="joined", + events=events, + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=None if room_id in newly_joined_rooms else since_token, + upto_token=prev_batch_token, + )) + else: + room_entries.append(RoomSyncResultBuilder( + room_id=room_id, + rtype="joined", + events=[], + newly_joined=room_id in newly_joined_rooms, + full_state=False, + since_token=since_token, + upto_token=since_token, + )) + + defer.returnValue((room_entries, invited, newly_joined_rooms)) + + @defer.inlineCallbacks + def _get_all_rooms(self, sync_result_builder, ignored_users): + """Returns entries for all rooms for the user. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + + Returns: + Deferred(tuple): Returns a tuple of the form: + `([RoomSyncResultBuilder], [InvitedSyncResult], [])` + """ + + user_id = sync_result_builder.sync_config.user.to_string() + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + membership_list = ( + Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN + ) + + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=user_id, + membership_list=membership_list + ) + + room_entries = [] + invited = [] + + for event in room_list: + if event.membership == Membership.JOIN: + room_entries.append(RoomSyncResultBuilder( + room_id=event.room_id, + rtype="joined", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=now_token, + )) + elif event.membership == Membership.INVITE: + if event.sender in ignored_users: + continue + invite = yield self.store.get_event(event.event_id) + invited.append(InvitedSyncResult( + room_id=event.room_id, + invite=invite, + )) + elif event.membership in (Membership.LEAVE, Membership.BAN): + # Always send down rooms we were banned or kicked from. + if not sync_config.filter_collection.include_leave: + if event.membership == Membership.LEAVE: + if user_id == event.sender: + continue + + leave_token = now_token.copy_and_replace( + "room_key", "s%d" % (event.stream_ordering,) + ) + room_entries.append(RoomSyncResultBuilder( + room_id=event.room_id, + rtype="archived", + events=None, + newly_joined=False, + full_state=True, + since_token=since_token, + upto_token=leave_token, + )) + + defer.returnValue((room_entries, invited, [])) + + @defer.inlineCallbacks + def _generate_room_entry(self, sync_result_builder, ignored_users, + room_builder, ephemeral, tags, account_data, + always_include=False): + """Populates the `joined` and `archived` section of `sync_result_builder` + based on the `room_builder`. + + Args: + sync_result_builder(SyncResultBuilder) + ignored_users(set(str)): Set of users ignored by user. + room_builder(RoomSyncResultBuilder) + ephemeral(list): List of new ephemeral events for room + tags(list): List of *all* tags for room, or None if there has been + no change. + account_data(list): List of new account data for room + always_include(bool): Always include this room in the sync response, + even if empty. + """ + newly_joined = room_builder.newly_joined + full_state = ( + room_builder.full_state + or newly_joined + or sync_result_builder.full_state + ) + events = room_builder.events + + # We want to shortcut out as early as possible. + if not (always_include or account_data or ephemeral or full_state): + if events == [] and tags is None: + return + + since_token = sync_result_builder.since_token + now_token = sync_result_builder.now_token + sync_config = sync_result_builder.sync_config + + room_id = room_builder.room_id + since_token = room_builder.since_token + upto_token = room_builder.upto_token + + batch = yield self._load_filtered_recents( + room_id, sync_config, + now_token=upto_token, + since_token=since_token, + recents=events, + newly_joined_room=newly_joined, + ) + + account_data_events = [] + if tags is not None: + account_data_events.append({ + "type": "m.tag", + "content": {"tags": tags}, + }) + + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + account_data = sync_config.filter_collection.filter_room_account_data( + account_data_events + ) + + ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) + + if not (always_include or batch or account_data or ephemeral or full_state): + return + + state = yield self.compute_state_delta( + room_id, batch, sync_config, since_token, now_token, + full_state=full_state + ) + + if room_builder.rtype == "joined": + unread_notifications = {} + room_sync = JoinedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + ephemeral=ephemeral, + account_data=account_data_events, + unread_notifications=unread_notifications, + ) + + if room_sync or always_include: + notifs = yield self.unread_notifs_for_room_id( + room_id, sync_config + ) + + if notifs is not None: + unread_notifications["notification_count"] = notifs["notify_count"] + unread_notifications["highlight_count"] = notifs["highlight_count"] + + sync_result_builder.joined.append(room_sync) + elif room_builder.rtype == "archived": + room_sync = ArchivedSyncResult( + room_id=room_id, + timeline=batch, + state=state, + account_data=account_data, + ) + if room_sync or always_include: + sync_result_builder.archived.append(room_sync) + else: + raise Exception("Unrecognized rtype: %r", room_builder.rtype) def _action_has_highlight(actions): @@ -997,25 +1127,72 @@ def _calculate_state(timeline_contains, timeline_start, previous, current): Returns: dict """ - event_id_to_state = { - e.event_id: e - for e in itertools.chain( - timeline_contains.values(), - previous.values(), - timeline_start.values(), - current.values(), + event_id_to_key = { + e: key + for key, e in itertools.chain( + timeline_contains.items(), + previous.items(), + timeline_start.items(), + current.items(), ) } - c_ids = set(e.event_id for e in current.values()) - tc_ids = set(e.event_id for e in timeline_contains.values()) - p_ids = set(e.event_id for e in previous.values()) - ts_ids = set(e.event_id for e in timeline_start.values()) + c_ids = set(e for e in current.values()) + tc_ids = set(e for e in timeline_contains.values()) + p_ids = set(e for e in previous.values()) + ts_ids = set(e for e in timeline_start.values()) state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids - evs = (event_id_to_state[e] for e in state_ids) return { - (e.type, e.state_key): e - for e in evs + event_id_to_key[e]: e for e in state_ids } + + +class SyncResultBuilder(object): + "Used to help build up a new SyncResult for a user" + def __init__(self, sync_config, full_state, since_token, now_token): + """ + Args: + sync_config(SyncConfig) + full_state(bool): The full_state flag as specified by user + since_token(StreamToken): The token supplied by user, or None. + now_token(StreamToken): The token to sync up to. + """ + self.sync_config = sync_config + self.full_state = full_state + self.since_token = since_token + self.now_token = now_token + + self.presence = [] + self.account_data = [] + self.joined = [] + self.invited = [] + self.archived = [] + self.device = [] + + +class RoomSyncResultBuilder(object): + """Stores information needed to create either a `JoinedSyncResult` or + `ArchivedSyncResult`. + """ + def __init__(self, room_id, rtype, events, newly_joined, full_state, + since_token, upto_token): + """ + Args: + room_id(str) + rtype(str): One of `"joined"` or `"archived"` + events(list): List of events to include in the room, (more events + may be added when generating result). + newly_joined(bool): If the user has newly joined the room + full_state(bool): Whether the full state should be sent in result + since_token(StreamToken): Earliest point to return events from, or None + upto_token(StreamToken): Latest point to return events from. + """ + self.room_id = room_id + self.rtype = rtype + self.events = events + self.newly_joined = newly_joined + self.full_state = full_state + self.since_token = since_token + self.upto_token = upto_token diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 8ce27f49ec..0eea7f8f9c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,12 +15,11 @@ from twisted.internet import defer -from ._base import BaseHandler - from synapse.api.errors import SynapseError, AuthError -from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.logcontext import preserve_fn from synapse.util.metrics import Measure -from synapse.types import UserID +from synapse.util.wheel_timer import WheelTimer +from synapse.types import UserID, get_domain_from_id import logging @@ -32,25 +31,38 @@ logger = logging.getLogger(__name__) # A tiny object useful for storing a user's membership in a room, as a mapping # key -RoomMember = namedtuple("RoomMember", ("room_id", "user")) +RoomMember = namedtuple("RoomMember", ("room_id", "user_id")) + + +# How often we expect remote servers to resend us presence. +FEDERATION_TIMEOUT = 60 * 1000 + +# How often to resend typing across federation. +FEDERATION_PING_INTERVAL = 40 * 1000 -class TypingNotificationHandler(BaseHandler): +class TypingHandler(object): def __init__(self, hs): - super(TypingNotificationHandler, self).__init__(hs) + self.store = hs.get_datastore() + self.server_name = hs.config.server_name + self.auth = hs.get_auth() + self.is_mine_id = hs.is_mine_id + self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() - self.homeserver = hs + self.hs = hs self.clock = hs.get_clock() + self.wheel_timer = WheelTimer(bucket_size=5000) - self.federation = hs.get_replication_layer() + self.federation = hs.get_federation_sender() - self.federation.register_edu_handler("m.typing", self._recv_edu) + hs.get_replication_layer().register_edu_handler("m.typing", self._recv_edu) hs.get_distributor().observe("user_left_room", self.user_left_room) self._member_typing_until = {} # clock time we expect to stop - self._member_typing_timer = {} # deferreds to manage theabove + self._member_last_federation_poke = {} # map room IDs to serial numbers self._room_serials = {} @@ -58,44 +70,78 @@ class TypingNotificationHandler(BaseHandler): # map room IDs to sets of users currently typing self._room_typing = {} - def tearDown(self): - """Cancels all the pending timers. - Normally this shouldn't be needed, but it's required from unit tests - to avoid a "Reactor was unclean" warning.""" - for t in self._member_typing_timer.values(): - self.clock.cancel_call_later(t) + self.clock.looping_call( + self._handle_timeouts, + 5000, + ) + + def _handle_timeouts(self): + logger.info("Checking for typing timeouts") + + now = self.clock.time_msec() + + members = set(self.wheel_timer.fetch(now)) + + for member in members: + if not self.is_typing(member): + # Nothing to do if they're no longer typing + continue + + until = self._member_typing_until.get(member, None) + if not until or until <= now: + logger.info("Timing out typing for: %s", member.user_id) + preserve_fn(self._stopped_typing)(member) + continue + + # Check if we need to resend a keep alive over federation for this + # user. + if self.hs.is_mine_id(member.user_id): + last_fed_poke = self._member_last_federation_poke.get(member, None) + if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: + preserve_fn(self._push_remote)( + member=member, + typing=True + ) + + # Add a paranoia timer to ensure that we always have a timer for + # each person typing. + self.wheel_timer.insert( + now=now, + obj=member, + then=now + 60 * 1000, + ) + + def is_typing(self, member): + return member.user_id in self._room_typing.get(member.room_id, []) @defer.inlineCallbacks def started_typing(self, target_user, auth_user, room_id, timeout): - if not self.hs.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has started typing in %s", target_user.to_string(), room_id + "%s has started typing in %s", target_user_id, room_id ) - until = self.clock.time_msec() + timeout - member = RoomMember(room_id=room_id, user=target_user) + member = RoomMember(room_id=room_id, user_id=target_user_id) - was_present = member in self._member_typing_until + was_present = member.user_id in self._room_typing.get(room_id, set()) - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) - - def _cb(): - logger.debug( - "%s has timed out in %s", target_user.to_string(), room_id - ) - self._stopped_typing(member) + now = self.clock.time_msec() + self._member_typing_until[member] = now + timeout - self._member_typing_until[member] = until - self._member_typing_timer[member] = self.clock.call_later( - timeout / 1000.0, _cb + self.wheel_timer.insert( + now=now, + obj=member, + then=now + timeout, ) if was_present: @@ -103,132 +149,146 @@ class TypingNotificationHandler(BaseHandler): defer.returnValue(None) yield self._push_update( - room_id=room_id, - user=target_user, + member=member, typing=True, ) @defer.inlineCallbacks def stopped_typing(self, target_user, auth_user, room_id): - if not self.hs.is_mine(target_user): + target_user_id = target_user.to_string() + auth_user_id = auth_user.to_string() + + if not self.is_mine_id(target_user_id): raise SynapseError(400, "User is not hosted on this Home Server") - if target_user != auth_user: + if target_user_id != auth_user_id: raise AuthError(400, "Cannot set another user's typing state") - yield self.auth.check_joined_room(room_id, target_user.to_string()) + yield self.auth.check_joined_room(room_id, target_user_id) logger.debug( - "%s has stopped typing in %s", target_user.to_string(), room_id + "%s has stopped typing in %s", target_user_id, room_id ) - member = RoomMember(room_id=room_id, user=target_user) - - if member in self._member_typing_timer: - self.clock.cancel_call_later(self._member_typing_timer[member]) - del self._member_typing_timer[member] + member = RoomMember(room_id=room_id, user_id=target_user_id) yield self._stopped_typing(member) @defer.inlineCallbacks def user_left_room(self, user, room_id): - if self.hs.is_mine(user): - member = RoomMember(room_id=room_id, user=user) + user_id = user.to_string() + if self.is_mine_id(user_id): + member = RoomMember(room_id=room_id, user_id=user_id) yield self._stopped_typing(member) @defer.inlineCallbacks def _stopped_typing(self, member): - if member not in self._member_typing_until: + if member.user_id not in self._room_typing.get(member.room_id, set()): # No point defer.returnValue(None) + self._member_typing_until.pop(member, None) + self._member_last_federation_poke.pop(member, None) + yield self._push_update( - room_id=member.room_id, - user=member.user, + member=member, typing=False, ) - del self._member_typing_until[member] - - if member in self._member_typing_timer: - # Don't cancel it - either it already expired, or the real - # stopped_typing() will cancel it - del self._member_typing_timer[member] - @defer.inlineCallbacks - def _push_update(self, room_id, user, typing): - localusers = set() - remotedomains = set() - - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers, remotedomains=remotedomains + def _push_update(self, member, typing): + if self.hs.is_mine_id(member.user_id): + # Only send updates for changes to our own users. + yield self._push_remote(member, typing) + + self._push_update_local( + member=member, + typing=typing ) - if localusers: - self._push_update_local( - room_id=room_id, - user=user, - typing=typing - ) - - deferreds = [] - for domain in remotedomains: - deferreds.append(self.federation.send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": room_id, - "user_id": user.to_string(), - "typing": typing, - }, - )) + @defer.inlineCallbacks + def _push_remote(self, member, typing): + users = yield self.state.get_current_user_in_room(member.room_id) + self._member_last_federation_poke[member] = self.clock.time_msec() + + now = self.clock.time_msec() + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_PING_INTERVAL, + ) - yield defer.DeferredList(deferreds, consumeErrors=True) + for domain in set(get_domain_from_id(u) for u in users): + if domain != self.server_name: + self.federation.send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": member.room_id, + "user_id": member.user_id, + "typing": typing, + }, + key=member, + ) @defer.inlineCallbacks def _recv_edu(self, origin, content): room_id = content["room_id"] - user = UserID.from_string(content["user_id"]) + user_id = content["user_id"] - localusers = set() + member = RoomMember(user_id=user_id, room_id=room_id) - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers - ) + # Check that the string is a valid user id + user = UserID.from_string(user_id) - if localusers: + if user.domain != origin: + logger.info( + "Got typing update from %r with bad 'user_id': %r", + origin, user_id, + ) + return + + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) + + if self.server_name in domains: + logger.info("Got typing update from %s: %r", user_id, content) + now = self.clock.time_msec() + self._member_typing_until[member] = now + FEDERATION_TIMEOUT + self.wheel_timer.insert( + now=now, + obj=member, + then=now + FEDERATION_TIMEOUT, + ) self._push_update_local( - room_id=room_id, - user=user, + member=member, typing=content["typing"] ) - def _push_update_local(self, room_id, user, typing): - room_set = self._room_typing.setdefault(room_id, set()) + def _push_update_local(self, member, typing): + room_set = self._room_typing.setdefault(member.room_id, set()) if typing: - room_set.add(user) + room_set.add(member.user_id) else: - room_set.discard(user) + room_set.discard(member.user_id) self._latest_room_serial += 1 - self._room_serials[room_id] = self._latest_room_serial + self._room_serials[member.room_id] = self._latest_room_serial - with PreserveLoggingContext(): - self.notifier.on_new_event( - "typing_key", self._latest_room_serial, rooms=[room_id] - ) + self.notifier.on_new_event( + "typing_key", self._latest_room_serial, rooms=[member.room_id] + ) def get_all_typing_updates(self, last_id, current_id): # TODO: Work out a way to do this without scanning the entire state. + if last_id == current_id: + return [] + rows = [] for room_id, serial in self._room_serials.items(): if last_id < serial and serial <= current_id: typing = self._room_typing[room_id] - typing_bytes = json.dumps([ - u.to_string() for u in typing - ], ensure_ascii=False) + typing_bytes = json.dumps(list(typing), ensure_ascii=False) rows.append((serial, room_id, typing_bytes)) rows.sort() return rows @@ -238,34 +298,26 @@ class TypingNotificationEventSource(object): def __init__(self, hs): self.hs = hs self.clock = hs.get_clock() - self._handler = None - self._room_member_handler = None - - def handler(self): - # Avoid cyclic dependency in handler setup - if not self._handler: - self._handler = self.hs.get_handlers().typing_notification_handler - return self._handler - - def room_member_handler(self): - if not self._room_member_handler: - self._room_member_handler = self.hs.get_handlers().room_member_handler - return self._room_member_handler + # We can't call get_typing_handler here because there's a cycle: + # + # Typing -> Notifier -> TypingNotificationEventSource -> Typing + # + self.get_typing_handler = hs.get_typing_handler def _make_event_for(self, room_id): - typing = self.handler()._room_typing[room_id] + typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", "room_id": room_id, "content": { - "user_ids": [u.to_string() for u in typing], + "user_ids": list(typing), }, } def get_new_events(self, from_key, room_ids, **kwargs): with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) - handler = self.handler() + handler = self.get_typing_handler() events = [] for room_id in room_ids: @@ -279,7 +331,7 @@ class TypingNotificationEventSource(object): return events, handler._latest_room_serial def get_current_key(self): - return self.handler()._latest_room_serial + return self.get_typing_handler()._latest_room_serial def get_pagination_rows(self, user, pagination_config, key): return ([], pagination_config.from_key) |